diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index c71ff30c..23c6f908 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -1,6 +1,7 @@ from .activations import * from .adaptive_avgmax_pool import \ adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d +from .attention import Attention, AttentionRope from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2 from .attention_pool import AttentionPoolLatent from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding @@ -41,6 +42,7 @@ from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, from .padding import get_padding, get_same_padding, pad_same from .patch_dropout import PatchDropout from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed +from .pool1d import global_pool_nlc from .pool2d_same import AvgPool2dSame, create_pool2d from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \ diff --git a/timm/layers/attention.py b/timm/layers/attention.py new file mode 100644 index 00000000..8e95a002 --- /dev/null +++ b/timm/layers/attention.py @@ -0,0 +1,212 @@ +from typing import Final, Optional, Type + +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .config import use_fused_attn +from .pos_embed_sincos import apply_rot_embed_cat + + +class Attention(nn.Module): + """Standard Multi-head Self Attention module with QKV projection. + + This module implements the standard multi-head attention mechanism used in transformers. + It supports both the fused attention implementation (scaled_dot_product_attention) for + efficiency when available, and a manual implementation otherwise. The module includes + options for QK normalization, attention dropout, and projection dropout. + """ + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + proj_bias: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: Type[nn.Module] = nn.LayerNorm, + ) -> None: + """Initialize the Attention module. + + Args: + dim: Input dimension of the token embeddings + num_heads: Number of attention heads + qkv_bias: Whether to use bias in the query, key, value projections + qk_norm: Whether to apply normalization to query and key vectors + proj_bias: Whether to use bias in the output projection + attn_drop: Dropout rate applied to the attention weights + proj_drop: Dropout rate applied after the output projection + norm_layer: Normalization layer constructor for QK normalization if enabled + """ + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.fused_attn = use_fused_attn() + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, + attn_mask=attn_mask, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + if attn_mask is not None: + attn = attn + attn_mask + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class AttentionRope(nn.Module): + """ A Self Attention module with ROPE support. + + Includes options for: + * QK normalization option + * Attention output (scale) normalization + * Fused or unfused QKV projection support + """ + fused_attn: torch.jit.Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + qkv_fused: bool = True, + num_prefix_tokens: int = 1, + attn_drop: float = 0., + proj_drop: float = 0., + attn_head_dim: Optional[int] = None, + norm_layer: Type[nn.Module] = None, + qk_norm: bool = False, + scale_norm: bool = False, + ): + """Initialize the Attention module. + + Args: + dim: Input dimension of the token embeddings + num_heads: Number of attention heads + qkv_bias: Whether to add a bias term to the query, key, and value projections + num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that + should not have position embeddings applied + attn_drop: Dropout rate for attention weights + proj_drop: Dropout rate for the output projection + attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads) + norm_layer: Normalization layer constructor to use for QK and scale normalization + qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer + scale_norm: Enable normalization (scaling) of attention output with norm_layer + """ + super().__init__() + if scale_norm or qk_norm: + assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True' + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + attn_dim = head_dim * self.num_heads + self.scale = head_dim ** -0.5 + self.num_prefix_tokens = num_prefix_tokens + self.fused_attn = use_fused_attn() + + if qkv_fused: + self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias) + self.q_proj = self.k_proj = self.v_proj = None + else: + self.qkv = None + self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias) + self.k_proj = nn.Linear(dim, attn_dim, bias=qkv_bias) + self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias) + + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity() + self.proj = nn.Linear(attn_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward( + self, + x, + rope: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + """Forward pass for the attention module. + + Args: + x: Input tensor of shape (batch_size, sequence_length, embedding_dim) + rope: Rotary position embeddings tensor for position-aware attention + attn_mask: Optional attention mask to apply during attention computation + + Returns: + Tensor of shape (batch_size, sequence_length, embedding_dim) + """ + B, N, C = x.shape + + if self.qkv is not None: + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim + else: + q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) # B, num_heads, N, C + k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) + v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) + + q, k = self.q_norm(q), self.k_norm(k) + + if rope is not None: + npt = self.num_prefix_tokens + q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v) + k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope)], dim=2).type_as(v) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, + attn_mask=attn_mask, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if attn_mask is not None: + attn_mask = attn_mask.to(torch.bool) + attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf")) + attn = attn.softmax(dim=-1) + + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.norm(x) + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/timm/layers/attention_pool.py b/timm/layers/attention_pool.py index 2e87566a..c2591a3b 100644 --- a/timm/layers/attention_pool.py +++ b/timm/layers/attention_pool.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Type import torch import torch.nn as nn @@ -28,8 +28,8 @@ class AttentionPoolLatent(nn.Module): latent_dim: int = None, pos_embed: str = '', pool_type: str = 'token', - norm_layer: Optional[nn.Module] = None, - act_layer: Optional[nn.Module] = nn.GELU, + norm_layer: Optional[Type[nn.Module]] = None, + act_layer: Optional[Type[nn.Module]] = nn.GELU, drop: float = 0.0, ): super().__init__() diff --git a/timm/layers/pool1d.py b/timm/layers/pool1d.py new file mode 100644 index 00000000..e20cf23b --- /dev/null +++ b/timm/layers/pool1d.py @@ -0,0 +1,26 @@ +import torch + + +def global_pool_nlc( + x: torch.Tensor, + pool_type: str = 'token', + num_prefix_tokens: int = 1, + reduce_include_prefix: bool = False, +): + if not pool_type: + return x + + if pool_type == 'token': + x = x[:, 0] # class token + else: + x = x if reduce_include_prefix else x[:, num_prefix_tokens:] + if pool_type == 'avg': + x = x.mean(dim=1) + elif pool_type == 'avgmax': + x = 0.5 * (x.amax(dim=1) + x.mean(dim=1)) + elif pool_type == 'max': + x = x.amax(dim=1) + else: + assert not pool_type, f'Unknown pool type {pool_type}' + + return x \ No newline at end of file diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index 5bb31af5..f8f169d5 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -87,6 +87,8 @@ def build_fourier_pos_embed( include_grid: bool = False, in_pixels: bool = True, ref_feat_shape: Optional[List[int]] = None, + grid_offset: float = 0., + grid_indexing: str = 'ij', dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, ) -> List[torch.Tensor]: @@ -102,6 +104,8 @@ def build_fourier_pos_embed( include_grid: Include the spatial grid in output. in_pixels: Output in pixel freq. ref_feat_shape: Reference feature shape for resize / fine-tune. + grid_offset: Constant offset to add to grid for non-pixel freq. + grid_indexing: Indexing mode for meshgrid ('ij' or 'xy') dtype: Output dtype. device: Output device. @@ -130,15 +134,21 @@ def build_fourier_pos_embed( dtype = bands.dtype if in_pixels: - t = [torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32) for s in feat_shape] + t = [ + torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32) + for s in feat_shape + ] else: - t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape] + t = [ + torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) + grid_offset + for s in feat_shape + ] if ref_feat_shape is not None: # eva's scheme for resizing rope embeddings (ref shape = pretrain) t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)] - grid = torch.stack(ndgrid(t), dim=-1) + grid = torch.stack(torch.meshgrid(t, indexing=grid_indexing), dim=-1) grid = grid.unsqueeze(-1) pos = grid * bands @@ -229,6 +239,8 @@ def build_rotary_pos_embed( linear_bands: bool = False, in_pixels: bool = True, ref_feat_shape: Optional[List[int]] = None, + grid_offset: float = 0., + grid_indexing: str = 'ij', dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, ): @@ -242,6 +254,9 @@ def build_rotary_pos_embed( temperature: Temperature (inv freq) for non-pixel mode linear_bands: Linearly (instead of log) spaced bands for pixel mode in_pixels: Pixel vs language (inv freq) mode. + ref_feat_shape: Reference feature shape for resize / fine-tune. + grid_offset: Constant offset to add to grid for non-pixel freq. + grid_indexing: Indexing mode for meshgrid ('ij' or 'xy') dtype: Output dtype. device: Output device. @@ -257,6 +272,8 @@ def build_rotary_pos_embed( linear_bands=linear_bands, in_pixels=in_pixels, ref_feat_shape=ref_feat_shape, + grid_offset=grid_offset, + grid_indexing=grid_indexing, device=device, dtype=dtype, ) @@ -289,6 +306,8 @@ class RotaryEmbedding(nn.Module): linear_bands: bool = False, feat_shape: Optional[List[int]] = None, ref_feat_shape: Optional[List[int]] = None, + grid_offset: float = 0., + grid_indexing: str = 'ij', ): super().__init__() self.dim = dim @@ -297,6 +316,8 @@ class RotaryEmbedding(nn.Module): self.in_pixels = in_pixels self.feat_shape = feat_shape self.ref_feat_shape = ref_feat_shape + self.grid_offset = grid_offset + self.grid_indexing = grid_indexing if feat_shape is None: # only cache bands @@ -328,6 +349,8 @@ class RotaryEmbedding(nn.Module): linear_bands=linear_bands, in_pixels=in_pixels, ref_feat_shape=self.ref_feat_shape, + grid_offset=self.grid_offset, + grid_indexing=self.grid_indexing, ) self.bands = None self.register_buffer( @@ -349,6 +372,9 @@ class RotaryEmbedding(nn.Module): shape, self.bands, in_pixels=self.in_pixels, + ref_feat_shape=self.ref_feat_shape, + grid_offset=self.grid_offset, + grid_indexing=self.grid_indexing, ) else: return self.pos_embed_sin, self.pos_embed_cos @@ -376,6 +402,8 @@ class RotaryEmbeddingCat(nn.Module): linear_bands: bool = False, feat_shape: Optional[List[int]] = None, ref_feat_shape: Optional[List[int]] = None, + grid_offset: float = 0., + grid_indexing: str = 'ij', ): super().__init__() self.dim = dim @@ -384,6 +412,8 @@ class RotaryEmbeddingCat(nn.Module): self.in_pixels = in_pixels self.feat_shape = feat_shape self.ref_feat_shape = ref_feat_shape + self.grid_offset = grid_offset + self.grid_indexing = grid_indexing if feat_shape is None: # only cache bands @@ -414,6 +444,8 @@ class RotaryEmbeddingCat(nn.Module): linear_bands=linear_bands, in_pixels=in_pixels, ref_feat_shape=self.ref_feat_shape, + grid_offset=self.grid_offset, + grid_indexing=self.grid_indexing, ) self.bands = None self.register_buffer( @@ -430,6 +462,8 @@ class RotaryEmbeddingCat(nn.Module): self.bands, in_pixels=self.in_pixels, ref_feat_shape=self.ref_feat_shape, + grid_offset=self.grid_offset, + grid_indexing=self.grid_indexing, ) return torch.cat(embeds, -1) elif self.pos_embed is not None: diff --git a/timm/models/eva.py b/timm/models/eva.py index 26c5278a..99a77bfd 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -25,6 +25,7 @@ Modifications by / Copyright 2023 Ross Wightman, original copyrights below # EVA models Copyright (c) 2022 BAAI-Vision # EVA02 models Copyright (c) 2023 BAAI-Vision import math +from functools import partial from typing import Callable, List, Optional, Tuple, Union import torch @@ -34,7 +35,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \ apply_rot_embed_cat, apply_keep_indices_nlc, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, \ - to_2tuple, use_fused_attn + global_pool_nlc, to_2tuple, use_fused_attn, AttentionRope, AttentionPoolLatent from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -45,6 +46,8 @@ __all__ = ['Eva'] class EvaAttention(nn.Module): + """ EVA Attention with ROPE, no k-bias, and fused/unfused qkv options + """ fused_attn: torch.jit.Final[bool] def __init__( @@ -53,55 +56,64 @@ class EvaAttention(nn.Module): num_heads: int = 8, qkv_bias: bool = True, qkv_fused: bool = True, - num_prefix_tokens: int = 1, qkv_bias_separate: bool = False, + num_prefix_tokens: int = 1, attn_drop: float = 0., proj_drop: float = 0., attn_head_dim: Optional[int] = None, norm_layer: Optional[Callable] = None, + qk_norm: bool = False, + scale_norm: bool = True, ): """ - Args: - dim: - num_heads: - qkv_bias: - qkv_fused: - attn_drop: - proj_drop: - attn_head_dim: - norm_layer: + dim: Input dimension of the token embeddings + num_heads: Number of attention heads + qkv_bias: Whether to add a bias term to the query, key, and value projections + qkv_fused: Whether qkv projections are fused into one projection or separate + qkv_bias_separate: Whether to apply bias to qkv as a separate addition or part of F.linear() call + num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that + should not have position embeddings applied + attn_drop: Dropout rate for attention weights + proj_drop: Dropout rate for the output projection + attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads) + norm_layer: Normalization layer constructor to use for QK and scale normalization + qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer + scale_norm: Enable normalization (scaling) of attention output with norm_layer """ super().__init__() + if scale_norm or qk_norm: + assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True' self.num_heads = num_heads head_dim = dim // num_heads if attn_head_dim is not None: head_dim = attn_head_dim - all_head_dim = head_dim * self.num_heads + attn_dim = head_dim * self.num_heads self.scale = head_dim ** -0.5 self.num_prefix_tokens = num_prefix_tokens self.fused_attn = use_fused_attn() self.qkv_bias_separate = qkv_bias_separate if qkv_fused: - self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + self.qkv = nn.Linear(dim, attn_dim * 3, bias=False) self.q_proj = self.k_proj = self.v_proj = None if qkv_bias: - self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) - self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False) - self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.q_bias = nn.Parameter(torch.zeros(attn_dim)) + self.register_buffer('k_bias', torch.zeros(attn_dim), persistent=False) + self.v_bias = nn.Parameter(torch.zeros(attn_dim)) else: self.q_bias = self.k_bias = self.v_bias = None else: - self.q_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias) - self.k_proj = nn.Linear(dim, all_head_dim, bias=False) - self.v_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias) + self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias) + self.k_proj = nn.Linear(dim, attn_dim, bias=False) + self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias) self.qkv = None self.q_bias = self.k_bias = self.v_bias = None - + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) - self.norm = norm_layer(all_head_dim) if norm_layer is not None else nn.Identity() - self.proj = nn.Linear(all_head_dim, dim) + self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity() + self.proj = nn.Linear(attn_dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward( @@ -110,6 +122,16 @@ class EvaAttention(nn.Module): rope: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, ): + """Forward pass for the attention module. + + Args: + x: Input tensor of shape (batch_size, sequence_length, embedding_dim) + rope: Rotary position embeddings tensor for position-aware attention + attn_mask: Optional attention mask to apply during attention computation + + Returns: + Tensor of shape (batch_size, sequence_length, embedding_dim) + """ B, N, C = x.shape if self.qkv is not None: @@ -129,6 +151,8 @@ class EvaAttention(nn.Module): k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) + q, k = self.q_norm(q), self.k_norm(k) + if rope is not None: npt = self.num_prefix_tokens q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v) @@ -172,6 +196,7 @@ class EvaBlock(nn.Module): scale_mlp: bool = False, scale_attn_inner: bool = False, num_prefix_tokens: int = 1, + attn_type: str = 'eva', proj_drop: float = 0., attn_drop: float = 0., drop_path: float = 0., @@ -180,28 +205,31 @@ class EvaBlock(nn.Module): norm_layer: Callable = LayerNorm, attn_head_dim: Optional[int] = None, ): - """ + """ Initialize the EVA transformer block. Args: - dim: - num_heads: - qkv_bias: - qkv_fused: - mlp_ratio: - swiglu_mlp: - scale_mlp: - scale_attn_inner: - proj_drop: - attn_drop: - drop_path: - init_values: - act_layer: - norm_layer: - attn_head_dim: + dim: Input dimension of the token embeddings + num_heads: Number of attention heads + qkv_bias: Whether to use bias terms in query, key, value projections + qkv_fused: Whether to use a single projection for query, key, value + mlp_ratio: Ratio of MLP hidden dimension to input dimension + swiglu_mlp: Whether to use SwiGLU activation in the MLP + scale_mlp: Whether to use normalization in the MLP + scale_attn_inner: Whether to use normalization within the attention mechanism + num_prefix_tokens: Number of tokens at the beginning of the sequence (class tokens, etc.) + attn_type: Type of attention module to use ('eva' or 'rope') + proj_drop: Dropout rate for projection layers + attn_drop: Dropout rate for attention matrix + drop_path: Stochastic depth rate + init_values: Initial value for LayerScale, None = no LayerScale + act_layer: Activation layer constructor + norm_layer: Normalization layer constructor + attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads) """ super().__init__() self.norm1 = norm_layer(dim) - self.attn = EvaAttention( + attn_cls = AttentionRope if attn_type == 'rope' else EvaAttention + self.attn = attn_cls( dim, num_heads=num_heads, qkv_bias=qkv_bias, @@ -210,7 +238,8 @@ class EvaBlock(nn.Module): attn_drop=attn_drop, proj_drop=proj_drop, attn_head_dim=attn_head_dim, - norm_layer=norm_layer if scale_attn_inner else None, + norm_layer=norm_layer, + scale_norm=scale_attn_inner, ) self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -266,6 +295,7 @@ class EvaBlockPostNorm(nn.Module): qkv_bias: bool = True, qkv_fused: bool = True, mlp_ratio: float = 4., + attn_type: str = 'eva', swiglu_mlp: bool = False, scale_mlp: bool = False, scale_attn_inner: bool = False, @@ -278,27 +308,30 @@ class EvaBlockPostNorm(nn.Module): norm_layer: Callable = nn.LayerNorm, attn_head_dim: Optional[int] = None, ): - """ + """ Initialize the post-norm EVA transformer block. Args: - dim: - num_heads: - qkv_bias: - qkv_fused: - mlp_ratio: - swiglu_mlp: - scale_mlp: - scale_attn_inner: - proj_drop: - attn_drop: - drop_path: - init_values: - act_layer: - norm_layer: - attn_head_dim: + dim: Input dimension of the token embeddings + num_heads: Number of attention heads + qkv_bias: Whether to use bias terms in query, key, value projections + qkv_fused: Whether to use a single projection for query, key, value + mlp_ratio: Ratio of MLP hidden dimension to input dimension + swiglu_mlp: Whether to use SwiGLU activation in the MLP + scale_mlp: Whether to use normalization in the MLP + scale_attn_inner: Whether to use normalization within the attention mechanism + num_prefix_tokens: Number of tokens at the beginning of the sequence (class tokens, etc.) + attn_type: Type of attention module to use ('eva' or 'rope') + proj_drop: Dropout rate for projection layers + attn_drop: Dropout rate for attention matrix + drop_path: Stochastic depth rate + init_values: Initial value for LayerScale, None = no LayerScale (NOTE: ignored for post-norm block) + act_layer: Activation layer constructor + norm_layer: Normalization layer constructor + attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads) """ super().__init__() - self.attn = EvaAttention( + attn_cls = AttentionRope if attn_type == 'rope' else EvaAttention + self.attn = attn_cls( dim, num_heads=num_heads, qkv_bias=qkv_bias, @@ -307,7 +340,8 @@ class EvaBlockPostNorm(nn.Module): attn_drop=attn_drop, proj_drop=proj_drop, attn_head_dim=attn_head_dim, - norm_layer=norm_layer if scale_attn_inner else None, + norm_layer=norm_layer, + scale_norm=scale_attn_inner, ) self.norm1 = norm_layer(dim) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -373,6 +407,7 @@ class Eva(nn.Module): swiglu_mlp: bool = False, scale_mlp: bool = False, scale_attn_inner: bool = False, + attn_type: str = 'eva', drop_rate: float = 0., pos_drop_rate: float = 0., patch_drop_rate: float = 0., @@ -385,44 +420,64 @@ class Eva(nn.Module): num_reg_tokens: int = 0, use_abs_pos_emb: bool = True, use_rot_pos_emb: bool = False, + rope_grid_offset: float = 0., + rope_grid_indexing: str = 'ij', use_post_norm: bool = False, + use_pre_transformer_norm: bool = False, + use_post_transformer_norm: Optional[bool] = None, + use_fc_norm: Optional[bool] = None, + attn_pool_num_heads: Optional[int] = None, + attn_pool_mlp_ratio: Optional[float] = None, dynamic_img_size: bool = False, dynamic_img_pad: bool = False, ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None, head_init_scale: float = 0.001, ): - """ + """Initialize the EVA Vision Transformer model. Args: - img_size: - patch_size: - in_chans: - num_classes: - global_pool: - embed_dim: - depth: - num_heads: - qkv_bias: - qkv_fused: - mlp_ratio: - swiglu_mlp: - scale_mlp: - scale_attn_inner: - drop_rate: - pos_drop_rate: - proj_drop_rate: - attn_drop_rate: - drop_path_rate: - norm_layer: - init_values: - class_token: - use_abs_pos_emb: - use_rot_pos_emb: - use_post_norm: - ref_feat_shape: - head_init_scale: + img_size: Input image size (single int for square, or tuple for rectangular) + patch_size: Patch size to divide image into tokens (single int for square, or tuple) + in_chans: Number of input image channels + num_classes: Number of classes (output dim) for classification head (final projection), 0 for pass-through + global_pool: Type of global pooling for final sequence ('avg', 'token', 'map', etc.) + embed_dim: Embedding dimension for tokens + depth: Number of transformer blocks + num_heads: Number of attention heads + qkv_bias: Enable bias for query, key, value projections + qkv_fused: Use a single projection for query, key, value + mlp_ratio: Ratio of mlp hidden dim to embedding dim + swiglu_mlp: Use SwiGLU activation in MLP + scale_mlp: Apply scaling normalization in MLP (normformer style) + scale_attn_inner: Apply scaling normalization inside attention + attn_type: Type of attention module to use + drop_rate: Dropout rate after final projection and pooling + pos_drop_rate: Dropout rate for positional embeddings + patch_drop_rate: Rate of dropping patches during training + proj_drop_rate: Dropout rate for projections + attn_drop_rate: Dropout rate for attention + drop_path_rate: Stochastic depth rate + norm_layer: Normalization layer constructor + init_values: Initial layer-scale values + class_token: Use class token + num_reg_tokens: Number of additional learnable 'register' tokens to add to the sequence + use_abs_pos_emb: Use absolute (learned) positional embeddings + use_rot_pos_emb: Use rotary position embeddings + rope_grid_offset: Offset for rotary position embedding grid + rope_grid_indexing: Indexing mode for rotary position embeddings ('ij' or 'xy') + use_post_norm: Use post-norm transformer block type + use_pre_transformer_norm: Use normalization layer before transformer blocks + use_post_transformer_norm: Use normalization layer after transformer blocks + use_fc_norm: Use normalization layer after pooling, before final classifier + attn_pool_num_heads: Number of heads in attention pooling + attn_pool_mlp_ratio: MLP ratio in attention pooling + dynamic_img_size: Support dynamic image sizes in forward pass + dynamic_img_pad: Apply dynamic padding for irregular image sizes + ref_feat_shape: Reference feature shape for rotary position embedding scale + head_init_scale: Initialization scale for classification head weights """ super().__init__() + assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map') self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models @@ -430,6 +485,17 @@ class Eva(nn.Module): self.dynamic_img_size = dynamic_img_size self.grad_checkpointing = False + # resolve norm / pool usage + activate_pre_norm = use_pre_transformer_norm + if use_fc_norm is not None: + activate_fc_norm = use_fc_norm # pass through if explicit + else: + activate_fc_norm = global_pool == 'avg' # default on if avg pool used + if use_post_transformer_norm is not None: + activate_post_norm = use_post_transformer_norm # pass through if explicit + else: + activate_post_norm = not activate_fc_norm # default on if fc_norm isn't active + embed_args = {} if dynamic_img_size: # flatten deferred until after pos embed @@ -440,6 +506,7 @@ class Eva(nn.Module): in_chans=in_chans, embed_dim=embed_dim, dynamic_img_pad=dynamic_img_pad, + bias=not use_pre_transformer_norm, **embed_args, ) num_patches = self.patch_embed.num_patches @@ -468,10 +535,14 @@ class Eva(nn.Module): in_pixels=False, feat_shape=None if dynamic_img_size else self.patch_embed.grid_size, ref_feat_shape=ref_feat_shape, + grid_offset=rope_grid_offset, + grid_indexing=rope_grid_indexing, ) else: self.rope = None + self.norm_pre = norm_layer(embed_dim) if activate_pre_norm else nn.Identity() + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule block_fn = EvaBlockPostNorm if use_post_norm else EvaBlock self.blocks = nn.ModuleList([ @@ -484,6 +555,7 @@ class Eva(nn.Module): swiglu_mlp=swiglu_mlp, scale_mlp=scale_mlp, scale_attn_inner=scale_attn_inner, + attn_type=attn_type, num_prefix_tokens=self.num_prefix_tokens, proj_drop=proj_drop_rate, attn_drop=attn_drop_rate, @@ -495,9 +567,21 @@ class Eva(nn.Module): self.feature_info = [ dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)] - use_fc_norm = self.global_pool == 'avg' - self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim) - self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + self.norm = norm_layer(embed_dim) if activate_post_norm else nn.Identity() + + if global_pool == 'map': + attn_pool_num_heads = attn_pool_num_heads or num_heads + attn_pool_mlp_ratio = attn_pool_mlp_ratio or mlp_ratio + self.attn_pool = AttentionPoolLatent( + self.embed_dim, + num_heads=attn_pool_num_heads, + mlp_ratio=attn_pool_mlp_ratio, + norm_layer=norm_layer, + act_layer=nn.GELU, + ) + else: + self.attn_pool = None + self.fc_norm = norm_layer(embed_dim) if activate_fc_norm else nn.Identity() self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() @@ -626,6 +710,7 @@ class Eva(nn.Module): B, _, height, width = x.shape x = self.patch_embed(x) x, rot_pos_embed = self._pos_embed(x) + x = self.norm_pre(x) if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript blocks = self.blocks else: @@ -668,13 +753,23 @@ class Eva(nn.Module): if prune_norm: self.norm = nn.Identity() if prune_head: + self.attn_pool = None self.fc_norm = nn.Identity() self.reset_classifier(0, '') return take_indices + def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor: + if self.attn_pool is not None: + x = self.attn_pool(x) + return x + pool_type = self.global_pool if pool_type is None else pool_type + x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens) + return x + def forward_features(self, x): x = self.patch_embed(x) x, rot_pos_embed = self._pos_embed(x) + x = self.norm_pre(x) for blk in self.blocks: if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint(blk, x, rope=rot_pos_embed) @@ -684,8 +779,7 @@ class Eva(nn.Module): return x def forward_head(self, x, pre_logits: bool = False): - if self.global_pool: - x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = self.pool(x) x = self.fc_norm(x) x = self.head_drop(x) return x if pre_logits else self.head(x) @@ -696,6 +790,67 @@ class Eva(nn.Module): return x +def _convert_pe( + state_dict, + model, + prefix: str = 'visual.', +): + """ Convert Perception Encoder weights """ + state_dict = state_dict.get('model', state_dict) + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + + out_dict = {} + swaps = [ + ('conv1', 'patch_embed.proj'), + ('positional_embedding', 'pos_embed'), + ('transformer.resblocks.', 'blocks.'), + ('ln_pre', 'norm_pre'), + ('ln_post', 'norm'), + ('ln_', 'norm'), + ('ls_1.gamma', 'gamma_1'), + ('ls_2.gamma', 'gamma_2'), + ('in_proj_', 'qkv.'), + ('out_proj', 'proj'), + ('mlp.c_fc', 'mlp.fc1'), + ('mlp.c_proj', 'mlp.fc2'), + ] + len_prefix = len(prefix) + for k, v in state_dict.items(): + if prefix: + if not k.startswith(prefix): + continue + k = k[len_prefix:] + + for sp in swaps: + k = k.replace(sp[0], sp[1]) + + if k.startswith('attn_pool'): + k = k.replace('attn_pool.attn', 'attn_pool') + k = k.replace('attn_pool.layernorm', 'attn_pool.norm') + k = k.replace('attn_pool.probe', 'attn_pool.latent') + if k.startswith('attn_pool.qkv'): + dim = v.shape[0] // 3 + if k.endswith('weight'): + out_dict['attn_pool.q.weight'] = v[:dim] + out_dict['attn_pool.kv.weight'] = v[dim:] + elif k.endswith('bias'): + out_dict['attn_pool.q.bias'] = v[:dim] + out_dict['attn_pool.kv.bias'] = v[dim:] + continue + elif k == 'proj': + k = 'head.weight' + v = v.transpose(0, 1) + out_dict['head.bias'] = torch.zeros(v.shape[0]) + elif k == 'class_embedding': + k = 'cls_token' + v = v.unsqueeze(0).unsqueeze(1) + elif k == 'pos_embed': + v = v.unsqueeze(0) + out_dict[k] = v + + return out_dict + + def checkpoint_filter_fn( state_dict, model, @@ -708,6 +863,13 @@ def checkpoint_filter_fn( state_dict = state_dict.get('model', state_dict) state_dict = state_dict.get('module', state_dict) state_dict = state_dict.get('state_dict', state_dict) + + # Loading Meta PE (Perception Encoder) weights + if 'visual.conv1.weight' in state_dict: + return _convert_pe(state_dict, model) + elif 'conv1.weight' in state_dict: + return _convert_pe(state_dict, model, prefix='') + # prefix for loading OpenCLIP compatible weights if 'visual.trunk.pos_embed' in state_dict: prefix = 'visual.trunk.' @@ -721,10 +883,9 @@ def checkpoint_filter_fn( len_prefix = len(prefix) for k, v in state_dict.items(): if prefix: - if k.startswith(prefix): - k = k[len_prefix:] - else: + if not k.startswith(prefix): continue + k = k[len_prefix:] if 'rope' in k: # fixed embedding no need to load buffer from checkpoint @@ -797,6 +958,17 @@ def _cfg(url='', **kwargs): } +def _pe_cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 0, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': 1.0, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + 'license': 'custom', **kwargs + } + + default_cfgs = generate_default_cfgs({ # EVA 01 CLIP fine-tuned on imagenet-1k @@ -984,6 +1156,49 @@ default_cfgs = generate_default_cfgs({ input_size=(3, 256, 256), crop_pct=0.95, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5) ), + + # Perception Encoder weights + 'vit_pe_core_base_patch16_224': _pe_cfg( + #hf_hub_id='facebook/pe_core_base_patch16_224_timm', + hf_hub_id='facebook/PE-Core-B16-224', + hf_hub_filename='PE-Core-B16-224.pt', + input_size=(3, 224, 224), + num_classes=1024, # output proj dim + ), + 'vit_pe_core_large_patch14_336': _pe_cfg( + hf_hub_id='facebook/PE-Core-L14-336', + hf_hub_filename='PE-Core-L14-336.pt', + input_size=(3, 336, 336), + num_classes=1024, # output proj dim + ), + 'vit_pe_core_gigantic_patch14_448': _pe_cfg( + #hf_hub_id='timm/', + hf_hub_id='facebook/PE-Core-G14-448', + hf_hub_filename='PE-Core-G14-448.pt', + input_size=(3, 448, 448), + num_classes=1280, # output proj dim + ), + 'vit_pe_lang_large_patch14_448': _pe_cfg( + #hf_hub_id='timm/', + hf_hub_id='facebook/PE-Lang-L14-448', + hf_hub_filename='PE-Lang-L14-448.pt', + input_size=(3, 448, 448), + num_classes=0, + ), + 'vit_pe_lang_gigantic_patch14_448': _pe_cfg( + #hf_hub_id='timm/', + hf_hub_id='facebook/PE-Lang-G14-448', + hf_hub_filename='PE-Lang-G14-448.pt', + input_size=(3, 448, 448), + num_classes=0, + ), + 'vit_pe_spatial_gigantic_patch14_448': _pe_cfg( + #hf_hub_id='timm/', + hf_hub_id='facebook/PE-Spatial-G14-448', + hf_hub_filename='PE-Spatial-G14-448.pt', + input_size=(3, 448, 448), + num_classes=0, + ), }) @@ -1330,3 +1545,142 @@ def vit_base_patch16_rope_reg1_gap_256(pretrained=False, **kwargs) -> Eva: ) model = _create_eva('vit_base_patch16_rope_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) return model + + +@register_model +def vit_pe_core_base_patch16_224(pretrained=False, **kwargs): + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + global_pool='map', + attn_type='rope', + use_pre_transformer_norm=True, + use_rot_pos_emb=True, + ref_feat_shape=(14, 14), + rope_grid_offset=1., + rope_grid_indexing='xy', + attn_pool_num_heads=8, + attn_pool_mlp_ratio=4., + norm_layer=partial(LayerNorm, eps=1e-5), + #dynamic_img_size=True + ) + return _create_eva('vit_pe_core_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def vit_pe_core_large_patch14_336(pretrained=False, **kwargs): + model_args = dict( + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4.0, + global_pool='map', + attn_type='rope', + use_pre_transformer_norm=True, + use_rot_pos_emb=True, + ref_feat_shape=(24, 24), + rope_grid_offset=1., + rope_grid_indexing='xy', + attn_pool_num_heads=8, + attn_pool_mlp_ratio=4., + norm_layer=partial(LayerNorm, eps=1e-5), + #dynamic_img_size=True, + ) + return _create_eva('vit_pe_core_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs): + model_args = dict( + patch_size=14, + embed_dim=1536, + depth=50, + num_heads=16, + mlp_ratio=8960 / 1536, + global_pool='map', + attn_type='rope', + class_token=False, + use_pre_transformer_norm=True, + use_rot_pos_emb=True, + ref_feat_shape=(32, 32), + rope_grid_indexing='xy', + attn_pool_num_heads=8, + attn_pool_mlp_ratio=4., + norm_layer=partial(LayerNorm, eps=1e-5), + #dynamic_img_size=True, + ) + return _create_eva('vit_pe_core_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs): + model_args = dict( + patch_size=14, + embed_dim=1024, + depth=23, + num_heads=16, + mlp_ratio=4.0, + attn_type='rope', + class_token=True, + use_rot_pos_emb=True, + ref_feat_shape=(32, 32), + rope_grid_offset=1., + rope_grid_indexing='xy', + use_pre_transformer_norm=True, + use_post_transformer_norm=False, + use_fc_norm=False, # explicitly disable + init_values=0.1, + norm_layer=partial(LayerNorm, eps=1e-5), + #dynamic_img_size=True, + ) + return _create_eva('vit_pe_lang_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs): + model_args = dict( + patch_size=14, + embed_dim=1536, + depth=47, + num_heads=16, + mlp_ratio=8960 / 1536, + attn_type='rope', + class_token=False, + use_rot_pos_emb=True, + ref_feat_shape=(32, 32), + rope_grid_indexing='xy', + use_pre_transformer_norm=True, + use_post_transformer_norm=False, + use_fc_norm=False, # explicitly disable + init_values=0.1, + norm_layer=partial(LayerNorm, eps=1e-5), + #dynamic_img_size=True, + ) + return _create_eva('vit_pe_lang_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs): + model_args = dict( + patch_size=14, + embed_dim=1536, + depth=50, + num_heads=16, + mlp_ratio=8960 / 1536, + attn_type='rope', + class_token=False, + use_rot_pos_emb=True, + ref_feat_shape=(32, 32), + rope_grid_indexing='xy', + use_pre_transformer_norm=True, + use_post_transformer_norm=False, + init_values=0.1, + norm_layer=partial(LayerNorm, eps=1e-5), + #dynamic_img_size=True, + ) + return _create_eva('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) +