diff --git a/mmpretrain/models/multimodal/__init__.py b/mmpretrain/models/multimodal/__init__.py index 73645f0f..e6650cfe 100644 --- a/mmpretrain/models/multimodal/__init__.py +++ b/mmpretrain/models/multimodal/__init__.py @@ -11,13 +11,29 @@ if WITH_MULTIMODAL: from .minigpt4 import * # noqa: F401, F403 from .ofa import * # noqa: F401, F403 from .otter import * # noqa: F401, F403 + from .ram import * # noqa: F401, F403 else: from mmpretrain.registry import MODELS from mmpretrain.utils.dependency import register_multimodal_placeholder register_multimodal_placeholder([ - 'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption', - 'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo', - 'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter', 'CLIP', - 'CLIPZeroShot' + 'Blip2Caption', + 'Blip2Retrieval', + 'Blip2VQA', + 'BlipCaption', + 'BlipNLVR', + 'BlipRetrieval', + 'BlipGrounding', + 'BlipVQA', + 'Flamingo', + 'OFA', + 'ChineseCLIP', + 'MiniGPT4', + 'Llava', + 'Otter', + 'CLIP', + 'CLIPZeroShot', + 'RAM', + 'RAMNormal', + 'RAMOpenset', ], MODELS) diff --git a/mmpretrain/models/multimodal/ram/__init__.py b/mmpretrain/models/multimodal/ram/__init__.py new file mode 100644 index 00000000..35619d88 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .ram import RAM, RAMNormal, RAMOpenset + +__all__ = ['RAM', 'RAMNormal', 'RAMOpenset'] diff --git a/mmpretrain/models/multimodal/ram/bert.py b/mmpretrain/models/multimodal/ram/bert.py new file mode 100644 index 00000000..f54b2ce8 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/bert.py @@ -0,0 +1,1197 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modify from: +# https://github.com/xinyu1205/recognize-anything/blob/main/ram/models/bert.py + +import math +from typing import Tuple + +import torch +import torch.utils.checkpoint +from torch import Tensor, device, nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions) +from transformers.modeling_utils import (PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer) +from transformers.models.bert.configuration_bert import BertConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class BertEmbeddings_nopos(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + # self.position_embeddings = nn.Embedding( + # config.max_position_embeddings, config.hidden_size) + '''self.LayerNorm is not snake-cased to stick with + TensorFlow model variable name and be able to load''' + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous + # in memory and exported when serialized + # self.register_buffer("position_ids", + # torch.arange(config.max_position_embeddings).expand((1, -1))) + # self.position_embedding_type = \ + # getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward(self, + input_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] # noqa: F841 + + # if position_ids is None: + # position_ids = self.position_ids[:, \ + # past_key_values_length : seq_length + \ + # past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + # if self.position_embedding_type == "absolute": + # position_embeddings = self.position_embeddings(position_ids) + # # print('add position_embeddings!!!!') + # embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with + # TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous + # in memory and exported when serialized + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + + self.config = config + + def forward(self, + input_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length: + seq_length + + past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + if self.position_embedding_type == 'absolute': + position_embeddings = self.position_embeddings(position_ids) + # print('add position_embeddings!!!!') + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and \ + not hasattr(config, 'embedding_size'): + raise ValueError('''The hidden size (%d) is not a multiple of + the number of attention heads (%d)''' % + (config.hidden_size, config.num_attention_heads)) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / + config.num_attention_heads) + self.all_head_size = self.num_attention_heads * \ + self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + if (self.position_embedding_type == 'relative_key' + or self.position_embedding_type == 'relative_key_query'): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + # print(self.key.weight.shape) + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # compatible with higher versions of transformers + if key_layer.shape[0] > query_layer.shape[0]: + key_layer = key_layer[:query_layer.shape[0], :, :, :] + attention_mask = attention_mask[:query_layer.shape[0], :, :] + value_layer = value_layer[:query_layer.shape[0], :, :, :] + + # Take the dot product between "query" and "key" + # to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + + if (self.position_embedding_type == 'relative_key' + or self.position_embedding_type == 'relative_key_query'): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == 'relative_key': + relative_position_scores = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == 'relative_key_query': + relative_position_scores_query = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + relative_position_scores_key = torch.einsum( + 'bhrd,lrd->bhlr', key_layer, positional_embedding) + attention_scores = attention_scores + \ + relative_position_scores_query + \ + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for + # all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + outputs = outputs + (past_key_value, ) + return outputs + + +class BertSelfOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, + self.self.attention_head_size, self.pruned_heads) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len( + heads) + self.self.all_head_size = self.self.attention_head_size * \ + self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if self.config.add_cross_attention: + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + mode=None, + ): + + if mode == 'tagging': + + assert encoder_hidden_states is not None, \ + '''encoder_hidden_states must be given + for cross-attention layers''' + + cross_attention_outputs = self.crossattention( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = cross_attention_outputs[ + 1:-1] # add cross attentions if we output attention weights + + present_key_value = cross_attention_outputs[-1] + + else: + # decoder uni-directional self-attention + # cached key/values tuple is at positions 1,2 + self_attn_past_key_value = \ + (past_key_value[:2] + if past_key_value is not None else None) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if mode == 'multimodal': + assert encoder_hidden_states is not None, \ + '''encoder_hidden_states must be + given for cross-attention layers''' + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[ + 1: + -1] # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward(self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output) + outputs = (layer_output, ) + outputs + + outputs = outputs + (present_key_value, ) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode='multimodal', + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + ) if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warn('''`use_cache=True` is incompatible with + gradient checkpointing. Setting `use_cache=False`...''' + ) + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, + output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + mode=mode, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + mode=mode, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that + # the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """An abstract class to handle weights initialization and a simple + interface for downloading and loading pretrained models.""" + + config_class = BertConfig + base_model_prefix = 'bert' + _keys_to_ignore_on_load_missing = [r'position_ids'] + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version + # which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """The model can behave as an encoder (with only self-attention) as well as + a decoder, in which case a layer of cross-attention is added between the + self-attention layers, following the architecture described in `Attention + is all you need `__ by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. + + Gomez, Lukasz Kaiser and Illia Polosukhin. argument and + :obj:`add_cross_attention` set to :obj:`True`; an + :obj:`encoder_hidden_states` is then expected as an input to the forward + pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """Prunes heads of the model. + + heads_to_prune: + dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask(self, attention_mask: Tensor, + input_shape: Tuple[int], device: device, + is_decoder: bool) -> Tensor: + """Makes broadcastable attention and causal masks so that future and + masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, + zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, + with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions + # [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it + # broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask + # in addition to the padding mask + # - if the model is an encoder, make the mask + # broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat( + batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type + # with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[ + 1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), + device=device, + dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = ( + causal_mask[:None, :, :] * + attention_mask[:, None, None, :]) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + '''Wrong shape for input_ids (shape {}) or attention_mask + (shape {})'''.format(input_shape, attention_mask.shape)) + + # Since attention_mask is 1.0 + # for positions we want to attend and 0.0 + # for masked positions, this operation will + # create a tensor which is 0.0 for positions + # we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores + # before the softmax, this is effectively + # the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj: + `(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer + of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj: + `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token + indices of the encoder input. This mask is used in + the cross-attention if the model is configured as + a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length : + obj:`config.n_layers` with each tuple having 4 tensors of shape : + obj:`(batch_size, num_heads, sequence_length - 1, + embed_size_per_head)`): + Contains precomputed key and value hidden states of the + attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally + input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to + this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj: + `(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value + states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + if is_decoder: + use_cache = ( + use_cache if use_cache is not None else self.config.use_cache) + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError('''You cannot specify both + input_ids and inputs_embeds at the same time''') + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError('''You have to specify either + input_ids or inputs_embeds or encoder_embeds''') + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[ + 2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device) + + # We can provide a self-attention mask of dimensions + # [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to + # make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = \ + (self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder)) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to + # [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = \ + (encoder_hidden_states[0].size()) + else: + encoder_batch_size, encoder_sequence_length, _ = \ + (encoder_hidden_states.size()) + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) + for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape + # [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape + # [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler( + sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction='mean', + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj: + `(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer + of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj: + `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token + indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. + Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj: + `(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right + language modeling loss (next word prediction). + Indices should be in + ``[-100, 0, ..., config.vocab_size]`` + (see ``input_ids`` docstring) Tokens with indices set to + ``-100`` are ignored (masked), the loss is only computed + for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length + :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj: + `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention + blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally + input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to + this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj: + `(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states + are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import (BertTokenizer, + BertLMHeadModel, BertConfig) + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained( + 'bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", + return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + # sequence_output.shape torch.Size([85, 30, 768]) + # prediction_scores.shape torch.Size([85, 30, 30524]) + # labels.shape torch.Size([85, 30]) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift + # prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, : + -1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss( + reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + if reduction == 'none': + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return ((lm_loss, ) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, + input_ids, + past=None, + attention_mask=None, + **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, + # the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + 'input_ids': + input_ids, + 'attention_mask': + attention_mask, + 'past_key_values': + past, + 'encoder_hidden_states': + model_kwargs.get('encoder_hidden_states', None), + 'encoder_attention_mask': + model_kwargs.get('encoder_attention_mask', None), + 'is_decoder': + True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past diff --git a/mmpretrain/models/multimodal/ram/config/__init__.py b/mmpretrain/models/multimodal/ram/config/__init__.py new file mode 100644 index 00000000..ef101fec --- /dev/null +++ b/mmpretrain/models/multimodal/ram/config/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/mmpretrain/models/multimodal/ram/config/ram_swin_large_14m.py b/mmpretrain/models/multimodal/ram/config/ram_swin_large_14m.py new file mode 100644 index 00000000..e4b88653 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/config/ram_swin_large_14m.py @@ -0,0 +1,93 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# data settings +test_transforms_cfg = [ + dict(type='Resize', scale=(384, 384), interpolation='bicubic'), + dict( + type='mmpretrain.PackInputs', + algorithm_keys=['text'], + meta_keys=['image_id', 'scale_factor'], + ), +] + + +def get_ram_cfg(mode='normal'): + assert mode in ['normal', 'openset'], 'mode must "normal" or "openset"' + model_type = 'RAMNormal' if mode == 'normal' else 'RAMOpenset' + model_cfg = dict( + type=model_type, + tokenizer=dict( + type='BertTokenizer', + name_or_path='/public/DATA/qbw/ckpt/bert-base-uncased', + use_fast=False), + vision_backbone=dict( + type='SwinTransformer', + arch='large', + img_size=384, + window_size=12, + ), + tag_encoder={ + 'architectures': ['BertModel'], + 'attention_probs_dropout_prob': 0.1, + 'hidden_act': 'gelu', + 'hidden_dropout_prob': 0.1, + 'hidden_size': 768, + 'initializer_range': 0.02, + 'intermediate_size': 3072, + 'layer_norm_eps': 1e-12, + 'max_position_embeddings': 512, + 'model_type': 'bert', + 'num_attention_heads': 12, + 'num_hidden_layers': 12, + 'pad_token_id': 0, + 'type_vocab_size': 2, + 'vocab_size': 30524, + 'encoder_width': 512, + 'add_cross_attention': True + }, + text_decoder={ + 'architectures': ['BertModel'], + 'attention_probs_dropout_prob': 0.1, + 'hidden_act': 'gelu', + 'hidden_dropout_prob': 0.1, + 'hidden_size': 768, + 'initializer_range': 0.02, + 'intermediate_size': 3072, + 'layer_norm_eps': 1e-12, + 'max_position_embeddings': 512, + 'model_type': 'bert', + 'num_attention_heads': 12, + 'num_hidden_layers': 12, + 'pad_token_id': 0, + 'type_vocab_size': 2, + 'vocab_size': 30524, + 'encoder_width': 768, + 'add_cross_attention': True + }, + tagging_head={ + 'architectures': ['BertModel'], + 'attention_probs_dropout_prob': 0.1, + 'hidden_act': 'gelu', + 'hidden_dropout_prob': 0.1, + 'hidden_size': 768, + 'initializer_range': 0.02, + 'intermediate_size': 3072, + 'layer_norm_eps': 1e-12, + 'max_position_embeddings': 512, + 'model_type': 'bert', + 'num_attention_heads': 4, + 'num_hidden_layers': 2, + 'pad_token_id': 0, + 'type_vocab_size': 2, + 'vocab_size': 30522, + 'encoder_width': 512, + 'add_cross_attention': True, + 'add_tag_cross_attention': False + }, + data_preprocessor=dict( + type='MultiModalDataPreprocessor', + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + to_rgb=False, + ), + ) + return model_cfg diff --git a/mmpretrain/models/multimodal/ram/data/ram_tag_list.pickle b/mmpretrain/models/multimodal/ram/data/ram_tag_list.pickle new file mode 100644 index 00000000..0519d1ee Binary files /dev/null and b/mmpretrain/models/multimodal/ram/data/ram_tag_list.pickle differ diff --git a/mmpretrain/models/multimodal/ram/data/ram_tag_list_chinese.pickle b/mmpretrain/models/multimodal/ram/data/ram_tag_list_chinese.pickle new file mode 100644 index 00000000..4abe105e Binary files /dev/null and b/mmpretrain/models/multimodal/ram/data/ram_tag_list_chinese.pickle differ diff --git a/mmpretrain/models/multimodal/ram/data/ram_tag_list_threshold.pickle b/mmpretrain/models/multimodal/ram/data/ram_tag_list_threshold.pickle new file mode 100644 index 00000000..2be681d6 Binary files /dev/null and b/mmpretrain/models/multimodal/ram/data/ram_tag_list_threshold.pickle differ diff --git a/mmpretrain/models/multimodal/ram/gradio_demo.py b/mmpretrain/models/multimodal/ram/gradio_demo.py new file mode 100644 index 00000000..206e6b40 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/gradio_demo.py @@ -0,0 +1,109 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse + +import gradio as gr +import torch + +from mmpretrain.registry import MODELS, TRANSFORMS +from .config.ram_swin_large_14m import get_ram_cfg, test_transforms_cfg +from .run.inference import inference + +parser = argparse.ArgumentParser( + description='RAM(Recognize Anything Model) demo') +parser.add_argument( + 'ram_ckpt', type=str, help='pretrained file for ram (absolute path)') +parser.add_argument( + 'clip_ckpt', + type=str, + help='clip vit-base-p16 pretrained file (absolute path)') +args = parser.parse_args() + +if torch.cuda.is_available(): + devices = [ + torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count()) + ] +elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + devices = [torch.device('mps')] +else: + devices = [torch.device('cpu')] + + +def get_free_device(): + if hasattr(torch.cuda, 'mem_get_info'): + free = [torch.cuda.mem_get_info(gpu)[0] for gpu in devices] + select = max(zip(free, range(len(free))))[1] + else: + import random + select = random.randint(0, len(devices) - 1) + return devices[select] + + +device = get_free_device() + + +def ram_inference(image, tag_list, mode, threshold): + test_transforms = TRANSFORMS.get('Compose')(transforms=test_transforms_cfg) + model = MODELS.build(get_ram_cfg(mode=mode)) + model.load_state_dict(torch.load(args.ram_ckpt)) + model.device = device + + if mode == 'openset': + categories = tag_list + if categories != '': + categories = categories.strip().split() + else: + categories = None + model.set_openset( + categories=categories, + clip_ckpt=args.clip_ckpt, + threshold=threshold) + + sample = dict(img=image) + result = inference(sample, model, test_transforms, mode=mode) + tag, tag_chinese, logits = \ + result.get('tag_output')[0][0], result.get('tag_output')[1][0],\ + result.get('logits_output')[0] + + def wrap(tags, logits): + if tags is None: + return 'Openset mode has no tag_en' + tag_lst = tags.split('|') + rt_lst = [] + for i, tag in enumerate(tag_lst): + tag = tag.strip() + rt_lst.append(tag + f': {logits[i]:.2f}') + return ' | '.join(rt_lst) + + return [wrap(tag, logits), wrap(tag_chinese, logits)] + + +def build_gradio(): + inputs = [ + gr.components.Image(label='image'), + gr.components.Textbox( + lines=2, + label='tag_list', + placeholder= + 'please input the categories split by keyboard "blank": ', + value=''), + gr.components.Radio(['normal', 'openset'], + label='mode', + value='normal'), + gr.components.Slider( + minimum=0, maximum=1, value=0.68, step=0.01, label='threshold') + ] + return gr.Interface( + fn=ram_inference, + inputs=inputs, + outputs=[ + gr.components.Textbox(), + gr.components.Textbox(info="it's translated from the english tags") + ]) + + +def main(): + build_gradio().launch() + + +if __name__ == '__main__': + main() diff --git a/mmpretrain/models/multimodal/ram/openset_utils.py b/mmpretrain/models/multimodal/ram/openset_utils.py new file mode 100644 index 00000000..5fa0f52e --- /dev/null +++ b/mmpretrain/models/multimodal/ram/openset_utils.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmpretrain.registry import MODELS + + +def article(name): + return 'an' if name[0] in 'aeiou' else 'a' + + +def processed_name(name, rm_dot=False): + # _ for lvis + # / for obj365 + res = name.replace('_', ' ').replace('/', ' or ').lower() + if rm_dot: + res = res.rstrip('.') + return res + + +single_template = ['a photo of a {}.'] + +multiple_templates = [ + 'There is {article} {} in the scene.', + 'There is the {} in the scene.', + 'a photo of {article} {} in the scene.', + 'a photo of the {} in the scene.', + 'a photo of one {} in the scene.', + 'itap of {article} {}.', + 'itap of my {}.', # itap: I took a picture of + 'itap of the {}.', + 'a photo of {article} {}.', + 'a photo of my {}.', + 'a photo of the {}.', + 'a photo of one {}.', + 'a photo of many {}.', + 'a good photo of {article} {}.', + 'a good photo of the {}.', + 'a bad photo of {article} {}.', + 'a bad photo of the {}.', + 'a photo of a nice {}.', + 'a photo of the nice {}.', + 'a photo of a cool {}.', + 'a photo of the cool {}.', + 'a photo of a weird {}.', + 'a photo of the weird {}.', + 'a photo of a small {}.', + 'a photo of the small {}.', + 'a photo of a large {}.', + 'a photo of the large {}.', + 'a photo of a clean {}.', + 'a photo of the clean {}.', + 'a photo of a dirty {}.', + 'a photo of the dirty {}.', + 'a bright photo of {article} {}.', + 'a bright photo of the {}.', + 'a dark photo of {article} {}.', + 'a dark photo of the {}.', + 'a photo of a hard to see {}.', + 'a photo of the hard to see {}.', + 'a low resolution photo of {article} {}.', + 'a low resolution photo of the {}.', + 'a cropped photo of {article} {}.', + 'a cropped photo of the {}.', + 'a close-up photo of {article} {}.', + 'a close-up photo of the {}.', + 'a jpeg corrupted photo of {article} {}.', + 'a jpeg corrupted photo of the {}.', + 'a blurry photo of {article} {}.', + 'a blurry photo of the {}.', + 'a pixelated photo of {article} {}.', + 'a pixelated photo of the {}.', + 'a black and white photo of the {}.', + 'a black and white photo of {article} {}.', + 'a plastic {}.', + 'the plastic {}.', + 'a toy {}.', + 'the toy {}.', + 'a plushie {}.', + 'the plushie {}.', + 'a cartoon {}.', + 'the cartoon {}.', + 'an embroidered {}.', + 'the embroidered {}.', + 'a painting of the {}.', + 'a painting of a {}.', +] + +openimages_rare_unseen = [ + 'Aerial photography', 'Aircraft engine', 'Ale', 'Aloe', 'Amphibian', + 'Angling', 'Anole', 'Antique car', 'Arcade game', 'Arthropod', + 'Assault rifle', 'Athletic shoe', 'Auto racing', 'Backlighting', + 'Bagpipes', 'Ball game', 'Barbecue chicken', 'Barechested', 'Barquentine', + 'Beef tenderloin', 'Billiard room', 'Billiards', 'Bird of prey', + 'Black swan', 'Black-and-white', 'Blond', 'Boating', 'Bonbon', + 'Bottled water', 'Bouldering', 'Bovine', 'Bratwurst', 'Breadboard', + 'Briefs', 'Brisket', 'Brochette', 'Calabaza', 'Camera operator', 'Canola', + 'Childbirth', 'Chordophone', 'Church bell', 'Classical sculpture', + 'Close-up', 'Cobblestone', 'Coca-cola', 'Combat sport', 'Comics', + 'Compact car', 'Computer speaker', 'Cookies and crackers', + 'Coral reef fish', 'Corn on the cob', 'Cosmetics', 'Crocodilia', + 'Digital camera', 'Dishware', 'Divemaster', 'Dobermann', 'Dog walking', + 'Domestic rabbit', 'Domestic short-haired cat', 'Double-decker bus', + 'Drums', 'Electric guitar', 'Electric piano', 'Electronic instrument', + 'Equestrianism', 'Equitation', 'Erinaceidae', 'Extreme sport', 'Falafel', + 'Figure skating', 'Filling station', 'Fire apparatus', 'Firearm', + 'Flatbread', 'Floristry', 'Forklift truck', 'Freight transport', + 'Fried food', 'Fried noodles', 'Frigate', 'Frozen yogurt', 'Frying', + 'Full moon', 'Galleon', 'Glacial landform', 'Gliding', 'Go-kart', 'Goats', + 'Grappling', 'Great white shark', 'Gumbo', 'Gun turret', 'Hair coloring', + 'Halter', 'Headphones', 'Heavy cruiser', 'Herding', 'High-speed rail', + 'Holding hands', 'Horse and buggy', 'Horse racing', 'Hound', + 'Hunting knife', 'Hurdling', 'Inflatable', 'Jackfruit', 'Jeans', 'Jiaozi', + 'Junk food', 'Khinkali', 'Kitesurfing', 'Lawn game', 'Leaf vegetable', + 'Lechon', 'Lifebuoy', 'Locust', 'Lumpia', 'Luxury vehicle', 'Machine tool', + 'Medical imaging', 'Melee weapon', 'Microcontroller', 'Middle ages', + 'Military person', 'Military vehicle', 'Milky way', 'Miniature Poodle', + 'Modern dance', 'Molluscs', 'Monoplane', 'Motorcycling', 'Musical theatre', + 'Narcissus', 'Nest box', 'Newsagent\'s shop', 'Nile crocodile', + 'Nordic skiing', 'Nuclear power plant', 'Orator', 'Outdoor shoe', + 'Parachuting', 'Pasta salad', 'Peafowl', 'Pelmeni', 'Perching bird', + 'Performance car', 'Personal water craft', 'Pit bull', 'Plant stem', + 'Pork chop', 'Portrait photography', 'Primate', 'Procyonidae', + 'Prosciutto', 'Public speaking', 'Racewalking', 'Ramen', + 'Rear-view mirror', 'Residential area', 'Ribs', 'Rice ball', + 'Road cycling', 'Roller skating', 'Roman temple', 'Rowing', 'Rural area', + 'Sailboat racing', 'Scaled reptile', 'Scuba diving', 'Senior citizen', + 'Shallot', 'Shinto shrine', 'Shooting range', 'Siberian husky', 'Sledding', + 'Soba', 'Solar energy', 'Sport climbing', 'Sport utility vehicle', + 'Steamed rice', 'Stemware', 'Sumo', 'Surfing Equipment', 'Team sport', + 'Touring car', 'Toy block', 'Trampolining', 'Underwater diving', + 'Vegetarian food', 'Wallaby', 'Water polo', 'Watercolor paint', 'Whiskers', + 'Wind wave', 'Woodwind instrument', 'Yakitori', 'Zeppelin' +] + + +def get_clip_model(): + model = dict( + type='CLIPZeroShot', + vision_backbone=dict( + type='VisionTransformer', + arch='base', + img_size=224, + patch_size=16, + drop_rate=0., + layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')), + pre_norm=True, + ), + projection=dict( + type='CLIPProjection', in_channels=768, out_channels=512), + text_backbone=dict( + type='CLIPTransformer', + width=512, + layers=12, + heads=8, + attn_mask=True, + ), + tokenizer=dict( + type='AutoTokenizer', + name_or_path='openai/clip-vit-base-patch16', + use_fast=False), + vocab_size=49408, + transformer_width=512, + proj_dim=512, + context_length=77, + data_preprocessor=dict( + type='MultiModalDataPreprocessor', + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + to_rgb=False, + ), + ) + return MODELS.build(model) + + +def build_openset_label_embedding(categories=None, clip_ckpt_path=''): + if categories is None: + print('Categories is None, so using rare_unseen categories') + categories = openimages_rare_unseen + model = get_clip_model() + model.load_state_dict(torch.load(clip_ckpt_path)) + templates = multiple_templates + + run_on_gpu = torch.cuda.is_available() + + with torch.no_grad(): + openset_label_embedding = [] + for category in categories: + texts = [ + template.format( + processed_name(category, rm_dot=True), + article=article(category)) for template in templates + ] + texts = [ + 'This is ' + text + if text.startswith('a') or text.startswith('the') else text + for text in texts + ] + texts = model.tokenize(texts) # tokenize + if run_on_gpu: + texts = texts.cuda() + model = model.cuda() + text_embeddings = model.extract_text_feat(texts) + text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) + text_embedding = text_embeddings.mean(dim=0) + text_embedding /= text_embedding.norm() + openset_label_embedding.append(text_embedding) + openset_label_embedding = torch.stack(openset_label_embedding, dim=1) + if run_on_gpu: + openset_label_embedding = openset_label_embedding.cuda() + + openset_label_embedding = openset_label_embedding.t() + return openset_label_embedding, categories diff --git a/mmpretrain/models/multimodal/ram/ram.py b/mmpretrain/models/multimodal/ram/ram.py new file mode 100644 index 00000000..c5d22f07 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/ram.py @@ -0,0 +1,332 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import pickle +from abc import abstractmethod +from typing import List, Optional + +import numpy as np +import torch +import torch.nn as nn +from mmengine.model import BaseModel + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from .bert import BertConfig, BertLMHeadModel, BertModel +from .openset_utils import build_openset_label_embedding +from .utils import tie_encoder_decoder_weights + + +def get_path(path): + file_path = os.path.abspath(os.path.dirname(__file__)) + if not os.path.isabs(path): + return os.path.join(file_path, path) + + +class RAM(BaseModel): + """The implementation of `RAM `_.""" + + def __init__(self, + tokenizer: dict, + vision_backbone: dict, + tag_encoder: dict, + tagging_head: dict, + text_decoder: dict, + device: str = 'cpu', + vision_width: int = 1536, + prompt='a picture of ', + threshold=0.68, + delete_tag_index=[], + tag_list='./data/ram_tag_list.pickle', + tag_list_chinese='./data/ram_tag_list_chinese.pickle', + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + + self.device = device + # build the visual encoder + self.visual_encoder = MODELS.build(vision_backbone) + + # build the tokenizer + self.tokenizer = TOKENIZER.build(tokenizer) + self.tokenizer.add_special_tokens({'bos_token': '[DEC]'}) + self.tokenizer.add_special_tokens( + {'additional_special_tokens': ['[ENC]']}) + self.tokenizer.enc_token_id = \ + self.tokenizer.additional_special_tokens_ids[0] + + # build the tag encoder + # encoder_config = BertConfig.from_json_file(med_config) + # encoder_config.encoder_width = 512 + encoder_config = BertConfig.from_dict(tag_encoder) + self.tag_encoder = BertModel( + config=encoder_config, add_pooling_layer=False) + + # build image-tag-text decoder + # decoder_config = BertConfig.from_json_file(med_config) + decoder_config = BertConfig.from_dict(text_decoder) + self.text_decoder = BertLMHeadModel(config=decoder_config) + + self.delete_tag_index = delete_tag_index + self.prompt = prompt + self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 + + # load tag list + self.tag_list = self.load_tag_list(get_path(tag_list)) + self.tag_list_chinese = self.load_tag_list(get_path(tag_list_chinese)) + + # create image-tag recognition decoder + self.threshold = threshold + self.num_class = len(self.tag_list) + # q2l_config = \ + # BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json') + # q2l_config.encoder_width = 512 + q2l_config = BertConfig.from_dict(tagging_head) + self.tagging_head = BertModel( + config=q2l_config, add_pooling_layer=False) + self.tagging_head.resize_token_embeddings(len(self.tokenizer)) + self.label_embed = nn.Parameter( + torch.zeros(self.num_class, q2l_config.encoder_width)) + + if q2l_config.hidden_size != 512: + self.wordvec_proj = nn.Linear(512, q2l_config.hidden_size) + else: + self.wordvec_proj = nn.Identity() + + self.fc = nn.Linear(q2l_config.hidden_size, 1) + + self.del_selfattention() + + # share weights of the lowest 2-layer of + # "image-tag interaction encoder" with + # the "image-tag recogntion decoder" + tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '', + ' ') + self.image_proj = nn.Linear(vision_width, 512) + # self.label_embed = nn.Parameter(torch.load( + # f'{CONFIG_PATH}/data/textual_label_embedding.pth', + # map_location='cpu').float()) + + # adjust thresholds for some tags + self.class_threshold = torch.ones(self.num_class) * self.threshold + ram_class_threshold_path = get_path( + './data/ram_tag_list_threshold.pickle') + with open(ram_class_threshold_path, 'rb') as f: + ram_class_threshold = pickle.load(f) + for key, value in enumerate(ram_class_threshold): + self.class_threshold[key] = value + + def load_tag_list(self, tag_list_file): + with open(tag_list_file, 'rb') as f: + tag_list = pickle.load(f) + tag_list = np.array(tag_list) + return tag_list + + # delete self-attention layer of image-tag recognition decoder + # to reduce computation, follower Query2Label + def del_selfattention(self): + del self.tagging_head.embeddings + for layer in self.tagging_head.encoder.layer: + del layer.attention + + def get_label_embed(self): + return torch.nn.functional.relu(self.wordvec_proj(self.label_embed)) + + def extract_visual_feature(self, images): + image_embeds = self.visual_encoder(images)[0] + image_embeds = image_embeds.flatten(2, 3) + attn_pool = nn.AdaptiveAvgPool1d(1) + cls_token = attn_pool(image_embeds).permute(0, 2, 1).contiguous() + image_embeds = image_embeds.permute(0, 2, 1).contiguous() + image_embeds = torch.cat([cls_token, image_embeds], dim=1) + image_embeds = self.image_proj(image_embeds) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(images.device) + return image_embeds, image_atts + + def image2tag(self, label_embed, image_embeds, image_atts): + # recognized image tags using image-tag recogntiion decoder + # image_cls_embeds = image_embeds[:, 0, :] + image_spatial_embeds = image_embeds[:, 1:, :] + + bs = image_spatial_embeds.shape[0] + label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1) + tagging_embed = self.tagging_head( + encoder_embeds=label_embed, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False, + mode='tagging', + ) + + logits = self.fc(tagging_embed[0]).squeeze(-1) + return logits + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[list] = None, + mode: str = 'predict', + **kwargs, + ): + if mode == 'predict': + return self.predict(images, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + @abstractmethod + def predict(self, + images: torch.Tensor, + data_samples: DataSample = None) -> DataSample: + raise NotImplementedError + + +@MODELS.register_module() +class RAMNormal(RAM): + + def __init__(self, + tokenizer: dict, + vision_backbone: dict, + tag_encoder: dict, + tagging_head: dict, + text_decoder: dict, + device: str = 'cpu', + vision_width: int = 1536, + prompt='a picture of ', + threshold=0.68, + delete_tag_index=[], + tag_list='./data/ram_tag_list.pickle', + tag_list_chinese='./data/ram_tag_list_chinese.pickle', + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super().__init__( + tokenizer, + vision_backbone, + tag_encoder, + tagging_head, + text_decoder, + device, + vision_width, + prompt, + threshold, + delete_tag_index, + tag_list, + tag_list_chinese, + data_preprocessor, + init_cfg, + ) + + def tag_process(self, logits): + targets = torch.where( + torch.sigmoid(logits) > self.class_threshold.to(logits.device), + torch.tensor(1.0).to(logits.device), + torch.zeros(self.num_class).to(logits.device)) + + tag = targets.cpu().numpy() + tag[:, self.delete_tag_index] = 0 + tag_output = [] + tag_output_chinese = [] + logits_output = [] + + bs = logits.shape[0] + for b in range(bs): + index = np.argwhere(tag[b] == 1) + token = self.tag_list[index].squeeze(axis=1) + logits_output.append( + torch.sigmoid(logits)[b][index[:, 0]].cpu().numpy()) + tag_output.append(' | '.join(token)) + token_chinese = self.tag_list_chinese[index].squeeze(axis=1) + tag_output_chinese.append(' | '.join(token_chinese)) + + return [(tag_output, tag_output_chinese), logits_output] + + def predict(self, + images: torch.Tensor, + data_samples: DataSample = None) -> DataSample: + self.eval() + self.to(self.device) + images = images.to(self.device) + label_embed = self.get_label_embed() + image_embeds, image_atts = self.extract_visual_feature(images) + logits = self.image2tag(label_embed, image_embeds, image_atts) + tag_output, logits_output = self.tag_process(logits) + data_samples.set_field(logits_output, 'logits_output') + data_samples.set_field(tag_output, 'tag_output') + return data_samples + + +@MODELS.register_module() +class RAMOpenset(RAMNormal): + + def __init__(self, + tokenizer: dict, + vision_backbone: dict, + tag_encoder: dict, + tagging_head: dict, + text_decoder: dict, + device: str = 'cpu', + vision_width: int = 1536, + prompt='a picture of ', + threshold=0.68, + delete_tag_index=[], + tag_list='./data/ram_tag_list.pickle', + tag_list_chinese='./data/ram_tag_list_chinese.pickle', + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super().__init__( + tokenizer, + vision_backbone, + tag_encoder, + tagging_head, + text_decoder, + device, + vision_width, + prompt, + threshold, + delete_tag_index, + tag_list, + tag_list_chinese, + data_preprocessor, + init_cfg, + ) + + def set_openset(self, + categories: List[str] = None, + clip_ckpt: str = '', + threshold: float = 0.68): + openset_label_embedding, openset_categories = \ + build_openset_label_embedding( + categories, clip_ckpt + ) + self.tag_list = np.array(openset_categories) + self.label_embed = nn.Parameter(openset_label_embedding.float()) + self.num_class = len(openset_categories) + + # the threshold for unseen categories is often lower + self.class_threshold = torch.ones(self.num_class) * threshold + + def tag_process(self, logits): + targets = torch.where( + torch.sigmoid(logits) > self.class_threshold.to(logits.device), + torch.tensor(1.0).to(logits.device), + torch.zeros(self.num_class).to(logits.device)) + + tag = targets.cpu().numpy() + tag[:, self.delete_tag_index] = 0 + + bs = logits.shape[0] + tag_output = [] + logits_output = [] + for b in range(bs): + index = np.argwhere(tag[b] == 1) + token = self.tag_list[index].squeeze(axis=1) + logits_output.append( + torch.sigmoid(logits)[b][index[:, 0]].cpu().numpy()) + tag_output.append(' | '.join(token)) + + return [(tag_output, [None]), logits_output] diff --git a/mmpretrain/models/multimodal/ram/run/__init__.py b/mmpretrain/models/multimodal/ram/run/__init__.py new file mode 100644 index 00000000..ef101fec --- /dev/null +++ b/mmpretrain/models/multimodal/ram/run/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/mmpretrain/models/multimodal/ram/run/inference.py b/mmpretrain/models/multimodal/ram/run/inference.py new file mode 100644 index 00000000..da5afcf5 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/run/inference.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + + +def inference_ram(sample, model): + + with torch.no_grad(): + result = model.test_step(sample) + + return result + + +def inference_ram_openset(sample, model): + with torch.no_grad(): + result = model.test_step(sample) + + return result + + +def inference(sample, model, transforms, mode='normal'): + sample = transforms(sample) + if sample['inputs'].ndim == 3: + sample['inputs'] = sample['inputs'].unsqueeze(dim=0) + assert mode in ['normal', 'openset' + ], 'mode of inference must be "normal" or "openset"' + if mode == 'normal': + return inference_ram(sample, model) + else: + return inference_ram_openset(sample, model) diff --git a/mmpretrain/models/multimodal/ram/utils.py b/mmpretrain/models/multimodal/ram/utils.py new file mode 100644 index 00000000..32cb115b --- /dev/null +++ b/mmpretrain/models/multimodal/ram/utils.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from torch import nn + + +def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, + base_model_prefix: str, skip_key: str): + uninitialized_encoder_weights: List[str] = [] + if decoder.__class__ != encoder.__class__: + print(f'''{decoder.__class__} and {encoder.__class__} are not equal. + In this case make sure that + all encoder weights are correctly initialized.''') + + def tie_encoder_to_decoder_recursively( + decoder_pointer: nn.Module, + encoder_pointer: nn.Module, + module_name: str, + uninitialized_encoder_weights: List[str], + skip_key: str, + depth=0, + ): + assert isinstance(decoder_pointer, nn.Module) and isinstance( + encoder_pointer, nn.Module + ), f'{decoder_pointer} and {encoder_pointer}' + \ + 'have to be of type torch.nn.Module' + if hasattr(decoder_pointer, 'weight') and skip_key not in module_name: + assert hasattr(encoder_pointer, 'weight') + encoder_pointer.weight = decoder_pointer.weight + if hasattr(decoder_pointer, 'bias'): + assert hasattr(encoder_pointer, 'bias') + encoder_pointer.bias = decoder_pointer.bias + print(module_name + ' is tied') + return + + encoder_modules = encoder_pointer._modules + decoder_modules = decoder_pointer._modules + if len(decoder_modules) > 0: + assert (len(encoder_modules) > + 0), f'''Encoder module {encoder_pointer} + does not match decoder module {decoder_pointer}''' + + all_encoder_weights = set([ + module_name + '/' + sub_name + for sub_name in encoder_modules.keys() + ]) + encoder_layer_pos = 0 + for name, module in decoder_modules.items(): + if name.isdigit(): + encoder_name = str(int(name) + encoder_layer_pos) + decoder_name = name + if not isinstance( + decoder_modules[decoder_name], + type(encoder_modules[encoder_name])) and len( + encoder_modules) != len(decoder_modules): + # this can happen if the name corresponds to + # the position in a list module list of layers + # in this case the decoder has added a + # cross-attention that the encoder doesn't have + # thus skip this step and + # subtract one layer pos from encoder + encoder_layer_pos -= 1 + continue + elif name not in encoder_modules: + continue + elif depth > 500: + raise ValueError( + '''Max depth of recursive function `tie_encoder_to_decoder` reached. + It seems that there is a circular dependency + between two or more `nn.Modules` of your model.''') + else: + decoder_name = encoder_name = name + tie_encoder_to_decoder_recursively( + decoder_modules[decoder_name], + encoder_modules[encoder_name], + module_name + '/' + name, + uninitialized_encoder_weights, + skip_key, + depth=depth + 1, + ) + all_encoder_weights.remove(module_name + '/' + encoder_name) + + uninitialized_encoder_weights += list(all_encoder_weights) + + # tie weights recursively + tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, + uninitialized_encoder_weights, skip_key) diff --git a/mmpretrain/models/utils/tokenizer.py b/mmpretrain/models/utils/tokenizer.py index 5b8a324b..fddda432 100644 --- a/mmpretrain/models/utils/tokenizer.py +++ b/mmpretrain/models/utils/tokenizer.py @@ -12,6 +12,7 @@ from .huggingface import register_hf_tokenizer register_hf_tokenizer(AutoTokenizer) register_hf_tokenizer(LlamaTokenizer) +register_hf_tokenizer(BertTokenizer) @register_hf_tokenizer() diff --git a/tools/model_converters/ram2mmpretrain.py b/tools/model_converters/ram2mmpretrain.py new file mode 100644 index 00000000..5ee3b476 --- /dev/null +++ b/tools/model_converters/ram2mmpretrain.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict +from copy import deepcopy + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_swin(ckpt): + new_ckpt = OrderedDict() + convert_mapping = dict() + + def correct_unfold_reduction_order(x): + out_channel, in_channel = x.shape + x = x.reshape(out_channel, 4, in_channel // 4) + x = x[:, [0, 2, 1, 3], :].transpose(1, + 2).reshape(out_channel, in_channel) + return x + + def correct_unfold_norm_order(x): + in_channel = x.shape[0] + x = x.reshape(4, in_channel // 4) + x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) + return x + + for k, v in ckpt.items(): + if 'attn_mask' in k: + continue + if k.startswith('head'): + continue + elif k.startswith('layers'): + new_v = v + if 'attn.' in k: + new_k = k.replace('attn.', 'attn.w_msa.') + elif 'mlp.' in k: + if 'mlp.fc1.' in k: + new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.') + elif 'mlp.fc2.' in k: + new_k = k.replace('mlp.fc2.', 'ffn.layers.1.') + else: + new_k = k.replace('mlp.', 'ffn.') + elif 'downsample' in k: + new_k = k + if 'reduction.' in k: + new_v = correct_unfold_reduction_order(v) + elif 'norm.' in k: + new_v = correct_unfold_norm_order(v) + else: + new_k = k + new_k = new_k.replace('layers', 'stages', 1) + elif k.startswith('patch_embed'): + new_v = v + if 'proj' in k: + new_k = k.replace('proj', 'projection') + else: + new_k = k + elif k.startswith('norm'): + new_v = v + new_k = k.replace('norm', 'norm3') + else: + new_v = v + new_k = k + + new_ckpt[new_k] = new_v + convert_mapping[k] = new_k + + return new_ckpt, convert_mapping + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in official pretrained RAM models to' + 'MMPretrain style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + + visual_ckpt = OrderedDict() + for key in state_dict: + if key.startswith('visual_encoder.'): + new_key = key.replace('visual_encoder.', '') + visual_ckpt[new_key] = state_dict[key] + + new_visual_ckpt, convert_mapping = convert_swin(visual_ckpt) + new_ckpt = deepcopy(state_dict) + for key in state_dict: + if key.startswith('visual_encoder.'): + if 'attn_mask' in key: + del new_ckpt[key] + continue + del new_ckpt[key] + old_key = key.replace('visual_encoder.', '') + new_ckpt[key.replace(old_key, + convert_mapping[old_key])] = deepcopy( + new_visual_ckpt[key.replace( + old_key, + convert_mapping[old_key]).replace( + 'visual_encoder.', '')]) + + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(new_ckpt, args.dst) + + +if __name__ == '__main__': + main()