# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/zejiangh/MILAN
from collections import OrderedDict
from typing import Optional, Tuple, Union

import numpy as np
import torch
from mmengine.logging import MMLogger
from torch import nn


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function."""
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    """A faster version of GELU."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function."""
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    """Residual Attention Block (RAB).

    This module implements the same function as the MultiheadAttention in
    MMClassification, but with a different interface, which is mainly used
    in CLIP.

    Args:
        d_model (int): The feature dimension.
        n_head (int): The number of attention heads.
        attn_mask (torch.Tensor, optional): The attention mask.
            Defaults to None.
        return_attention (bool): Whether to also return the attention
            weights. Defaults to False.
    """

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: Optional[torch.Tensor] = None,
                 return_attention: bool = False) -> None:
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
                         ('gelu', QuickGELU()),
                         ('c_proj', nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

        self.return_attention = return_attention

    def attention(self, x: torch.Tensor) -> torch.Tensor:
        """Attention function."""
        self.attn_mask = self.attn_mask.to(
            dtype=x.dtype,
            device=x.device) if self.attn_mask is not None else None
        if self.return_attention:
            # return both the attention output and the attention weights
            return self.attn(
                x,
                x,
                x,
                need_weights=self.return_attention,
                attn_mask=self.attn_mask)
        else:
            # return only the attention output
            return self.attn(
                x,
                x,
                x,
                need_weights=self.return_attention,
                attn_mask=self.attn_mask)[0]

    def forward(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward function."""
        if self.return_attention:
            x_, attention = self.attention(self.ln_1(x))
            x = x + x_
            x = x + self.mlp(self.ln_2(x))
            return x, attention
        else:
            x = x + self.attention(self.ln_1(x))
            x = x + self.mlp(self.ln_2(x))
            return x


class Transformer(nn.Module):
    """Transformer.

    Both visual and text branches use this transformer.

    Args:
        width (int): The feature dimension.
        layers (int): The number of layers.
        heads (int): The number of attention heads.
        attn_mask (torch.Tensor, optional): The attention mask.
            Defaults to None.
    """

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 attn_mask: Optional[torch.Tensor] = None) -> None:
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList()
        # all blocks except the last one only return features; the last
        # block additionally returns the attention weights
        for _ in range(layers - 1):
            self.resblocks.append(
                ResidualAttentionBlock(width, heads, attn_mask))
        self.resblocks.append(
            ResidualAttentionBlock(
                width, heads, attn_mask, return_attention=True))

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward function."""
        z = []
        for idx, blk in enumerate(self.resblocks):
            if idx < self.layers - 1:
                x = blk(x)
                z.append(x.permute(1, 0, 2))
            else:
                x, attention = blk(x)
                z.append(x.permute(1, 0, 2))
        return x, attention, z


class VisionTransformer(nn.Module):
    """Vision Transformer for CLIP.

    Args:
        input_resolution (int): The image size.
        patch_size (int): The patch size.
        width (int): The feature dimension.
        layers (int): The number of layers.
        heads (int): The number of attention heads.
        output_dim (int): The output dimension.
        finetune (bool): Whether to finetune the model. Defaults to False.
        average_targets (int): The number of targets to average.
            Defaults to 1.
    """

    def __init__(self,
                 input_resolution: int,
                 patch_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 output_dim: int,
                 finetune=False,
                 average_targets: int = 1) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False)

        scale = width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(
            (input_resolution // patch_size)**2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.finetune = finetune

        if finetune is False:
            self.ln_post = LayerNorm(width)
            self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

        self.average_targets = average_targets

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward function."""
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1],
                      -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([
            self.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
            x
        ],
                      dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x, attention, z = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x)

        if self.proj is not None:
            x = x @ self.proj

        return x, attention


class CLIP(nn.Module):
    """CLIP.

    Args:
        embed_dim (int): The embedding dimension.
        image_resolution (int): The image size.
        vision_layers (int): The number of layers in the vision transformer.
        vision_width (int): The feature dimension in the vision transformer.
        vision_patch_size (int): The patch size in the vision transformer.
        context_length (int): The context length.
        vocab_size (int): The vocabulary size.
        transformer_width (int): The feature dimension in the text
            transformer.
        transformer_heads (int): The number of attention heads in the
            text transformer.
        transformer_layers (int): The number of layers in the text
            transformer.
        finetune (bool): Whether to finetune the model. Defaults to False.
        average_targets (int): The number of targets to average.
            Defaults to 1.
""" def __init__( self, embed_dim: int, image_resolution: int, vision_layers: Union[Tuple[int, int, int, int], int], vision_width: int, vision_patch_size: int, context_length: int, vocab_size: int, transformer_width: int, transformer_heads: int, transformer_layers: int, finetune: bool = False, average_targets: int = 1, ) -> None: super().__init__() self.context_length = context_length vision_heads = vision_width // 64 self.visual = VisionTransformer( input_resolution=image_resolution, patch_size=vision_patch_size, width=vision_width, layers=vision_layers, heads=vision_heads, output_dim=embed_dim, finetune=finetune, average_targets=average_targets, ) self.transformer = Transformer( width=transformer_width, layers=transformer_layers, heads=transformer_heads, attn_mask=self.build_attention_mask()) self.vocab_size = vocab_size self.token_embedding = nn.Embedding(vocab_size, transformer_width) self.positional_embedding = nn.Parameter( torch.empty(self.context_length, transformer_width)) self.ln_final = LayerNorm(transformer_width) self.text_projection = nn.Parameter( torch.empty(transformer_width, embed_dim)) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) self.initialize_parameters() def initialize_parameters(self) -> None: """Initialize the parameters. The pretrained weight will override the initialized parameters by this function. """ nn.init.normal_(self.token_embedding.weight, std=0.02) nn.init.normal_(self.positional_embedding, std=0.01) proj_std = (self.transformer.width**-0.5) * ( (2 * self.transformer.layers)**-0.5) attn_std = self.transformer.width**-0.5 fc_std = (2 * self.transformer.width)**-0.5 for block in self.transformer.resblocks: nn.init.normal_(block.attn.in_proj_weight, std=attn_std) nn.init.normal_(block.attn.out_proj.weight, std=proj_std) nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) if self.text_projection is not None: nn.init.normal_( self.text_projection, std=self.transformer.width**-0.5) def build_attention_mask(self) -> torch.Tensor: """Build the attention mask.""" # lazily create causal attention mask, with full attention between the # vision tokens pytorch uses additive attention mask; fill with -inf mask = torch.empty(self.context_length, self.context_length) mask.fill_(float('-inf')) mask.triu_(1) # zero out the lower diagonal return mask @property def dtype(self) -> torch.dtype: """Get the dtype.""" return self.visual.conv1.weight.dtype def encode_image(self, image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Encode the image. Get the feature and attention mask from the last layer of the visual branch of CLIP. Args: image (torch.Tensor): The image tensor with shape NCHW. Returns: Tuple[torch.Tensor, torch.Tensor]: The feature and attention mask. """ return self.visual(image.type(self.dtype)) def build_clip_model(state_dict: dict, finetune: bool = False, average_targets: int = 1) -> nn.Module: """Build the CLIP model. Args: state_dict (dict): The pretrained state dict. finetune (bool): Whether to fineturn the model. average_targets (bool): Whether to average the target. Returns: nn.Module: The CLIP model. 
""" vit = 'visual.proj' in state_dict if vit: vision_width = state_dict['visual.conv1.weight'].shape[0] vision_layers = len([ k for k in state_dict.keys() if k.startswith('visual.') and k.endswith('.attn.in_proj_weight') ]) vision_patch_size = state_dict['visual.conv1.weight'].shape[-1] grid_size = round( (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5) image_resolution = vision_patch_size * grid_size embed_dim = state_dict['text_projection'].shape[1] context_length = state_dict['positional_embedding'].shape[0] vocab_size = state_dict['token_embedding.weight'].shape[0] transformer_width = state_dict['ln_final.weight'].shape[0] transformer_heads = transformer_width // 64 transformer_layers = len( set( k.split('.')[2] for k in state_dict if k.startswith('transformer.resblocks'))) model = CLIP( embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size, context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, finetune, average_targets, ) for key in ['input_resolution', 'context_length', 'vocab_size']: if key in state_dict: del state_dict[key] msg = model.load_state_dict(state_dict, strict=False) MMLogger.get_current_instance().info(f'Load CLIP model: {msg}') return model.eval()