diff --git a/timm/layers/attention_pool.py b/timm/layers/attention_pool.py
new file mode 100644
index 00000000..41e404d2
--- /dev/null
+++ b/timm/layers/attention_pool.py
@@ -0,0 +1,105 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .config import use_fused_attn
+from .mlp import Mlp
+from .weight_init import trunc_normal_tf_
+
+
+class AttentionPoolLatent(nn.Module):
+    """ Attention pooling w/ latent query
+    """
+    fused_attn: torch.jit.Final[bool]
+
+    def __init__(
+            self,
+            in_features: int,
+            out_features: int = None,
+            embed_dim: int = None,
+            num_heads: int = 8,
+            feat_size: Optional[int] = None,  # required when pos_embed='abs'
+            mlp_ratio: float = 4.0,
+            qkv_bias: bool = True,
+            qk_norm: bool = False,
+            latent_len: int = 1,
+            latent_dim: int = None,
+            pos_embed: str = '',
+            pool_type: str = 'token',
+            norm_layer: Optional[nn.Module] = None,
+            drop: float = 0.0,
+    ):
+        super().__init__()
+        embed_dim = embed_dim or in_features
+        out_features = out_features or in_features
+        assert embed_dim % num_heads == 0
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        self.scale = self.head_dim ** -0.5
+        self.pool = pool_type
+        self.fused_attn = use_fused_attn()
+
+        if pos_embed == 'abs':
+            assert feat_size is not None
+            spatial_len = feat_size
+            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
+        else:
+            self.pos_embed = None
+
+        self.latent_dim = latent_dim or embed_dim
+        self.latent_len = latent_len
+        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))
+
+        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
+        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.proj = nn.Linear(embed_dim, embed_dim)
+        self.proj_drop = nn.Dropout(drop)
+
+        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
+        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
+
+        self.init_weights()
+
+    def init_weights(self):
+        if self.pos_embed is not None:
+            trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
+        trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5)
+
+    def forward(self, x):
+        B, N, C = x.shape
+
+        if self.pos_embed is not None:
+            # FIXME interpolate
+            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
+
+        q_latent = self.latent.expand(B, -1, -1)
+        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        k, v = kv.unbind(0)
+
+        q, k = self.q_norm(q), self.k_norm(k)
+
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        x = x + self.mlp(self.norm(x))
+
+        # optional pool if latent seq_len > 1 and pooled output is desired
+        if self.pool == 'token':
+            x = x[:, 0]
+        elif self.pool == 'avg':
+            x = x.mean(1)
+        return x
\ No newline at end of file
diff --git a/timm/models/_hub.py b/timm/models/_hub.py
index e2152d21..720a5091 100644
--- a/timm/models/_hub.py
+++ b/timm/models/_hub.py
@@ -376,7 +376,7 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
     """
     if filename == HF_WEIGHTS_NAME:
        yield HF_SAFE_WEIGHTS_NAME
-    # if filename == HF_OPEN_CLIP_WEIGHTS_NAME:  # FIXME tracking safetensors yet
-    #     yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
+    if filename == HF_OPEN_CLIP_WEIGHTS_NAME:
+        yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
     if filename not in (HF_WEIGHTS_NAME, HF_OPEN_CLIP_WEIGHTS_NAME) and filename.endswith(".bin"):
         yield filename[:-4] + ".safetensors"
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index ff753679..bd9158aa 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -37,8 +37,8 @@ from torch.jit import Final
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
-    resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
+from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
+    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -377,95 +377,6 @@ class ParallelThingsBlock(nn.Module):
         return self._forward(x)
 
 
-class AttentionPoolLatent(nn.Module):
-    """ Attention pooling w/ latent query
-    """
-    def __init__(
-            self,
-            in_features: int,
-            out_features: int = None,
-            embed_dim: int = None,
-            num_heads: int = 8,
-            mlp_ratio: float = 4.0,
-            qkv_bias: bool = True,
-            qk_norm: bool = False,
-            latent_len: int = 1,
-            latent_dim: int = None,
-            pos_embed: str = '',
-            pool_type: str = 'token',
-            norm_layer: Optional[nn.Module] = None,
-            drop: float = 0.0,
-    ):
-        super().__init__()
-        embed_dim = embed_dim or in_features
-        out_features = out_features or in_features
-        assert embed_dim % num_heads == 0
-        self.num_heads = num_heads
-        self.head_dim = embed_dim // num_heads
-        self.scale = self.head_dim ** -0.5
-        self.pool = pool_type
-        self.fused_attn = use_fused_attn()
-
-        if pos_embed == 'abs':
-            spatial_len = self.feat_size
-            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
-        else:
-            self.pos_embed = None
-
-        self.latent_dim = latent_dim or embed_dim
-        self.latent_len = latent_len
-        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))
-
-        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
-        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
-        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.proj = nn.Linear(embed_dim, embed_dim)
-        self.proj_drop = nn.Dropout(drop)
-
-        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
-        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
-
-    def init_weights(self):
-        if self.pos_embed is not None:
-            trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
-
-    def forward(self, x):
-        B, N, C = x.shape
-
-        if self.pos_embed is not None:
-            # FIXME interpolate
-            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
-
-        q_latent = self.latent.expand(B, -1, -1)
-        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        k, v = kv.unbind(0)
-
-        q, k = self.q_norm(q), self.k_norm(k)
-
-        if self.fused_attn:
-            x = F.scaled_dot_product_attention(q, k, v)
-        else:
-            q = q * self.scale
-            attn = q @ k.transpose(-2, -1)
-            attn = attn.softmax(dim=-1)
-            x = attn @ v
-        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-
-        x = x + self.mlp(self.norm(x))
-
-        # optional pool if latent seq_len > 1 and pooled output is desired
-        if self.pool == 'token':
-            x = x[:, 0]
-        elif self.pool == 'avg':
-            x = x.mean(1)
-        return x
-
-
 class VisionTransformer(nn.Module):
     """ Vision Transformer
 
@@ -1072,6 +983,12 @@ def checkpoint_filter_fn(
     if "encoder" in state_dict:
         state_dict = _convert_ijepa(state_dict, model)
 
+    if 'visual.trunk.pos_embed' in state_dict:
+        # convert an OpenCLIP model with timm vision encoder
+        prefix = 'visual.trunk.'
+        state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
+        # FIXME remap final nn.Linear if it exists outside of the timm .trunk (i.e. in visual.head.proj)
+
     for k, v in state_dict.items():
         if 'patch_embed.proj.weight' in k:
             O, I, H, W = model.patch_embed.proj.weight.shape
@@ -1634,48 +1551,42 @@ default_cfgs = generate_default_cfgs({
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
 
-    'vit_base_patch16_siglip_224': _cfg(
-        file='/data/n/temp/siglip/webli_en_b16_224_63724782.npz',
-        custom_load=True,
-        # hf_hub_id='timm/',
+    'vit_base_patch16_siglip_224.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=0),
-    'vit_base_patch16_siglip_256': _cfg(
-        file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
-        custom_load=True,
+    'vit_base_patch16_siglip_256.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 256, 256),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_base_patch16_siglip_384': _cfg(
-        file='',
-        custom_load=True,
+    'vit_base_patch16_siglip_384.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 384, 384),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_base_patch16_siglip_512': _cfg(
-        file='',
-        custom_load=True,
+    'vit_base_patch16_siglip_512.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-512',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 512, 512),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_large_patch16_siglip_256': _cfg(
-        custom_load=True,
+    'vit_large_patch16_siglip_256.webli': _cfg(
+        hf_hub_id='timm/ViT-L-16-SigLIP-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 256, 256),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_large_patch16_siglip_384': _cfg(
-        custom_load=True,
+    'vit_large_patch16_siglip_384.webli': _cfg(
+        hf_hub_id='timm/ViT-L-16-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 384, 384),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_so400m_patch14_siglip_224': _cfg(
-        # file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
-        custom_load=True,
-        # hf_hub_id='timm/',
+    'vit_so400m_patch14_siglip_224.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=0),
-    'vit_so400m_patch14_siglip_384': _cfg(
-        #file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
-        custom_load=True,
-        # hf_hub_id='timm/',
+    'vit_so400m_patch14_siglip_384.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 384, 384),
         num_classes=0),
 })
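
Review note: the layer added in timm/layers/attention_pool.py is self-contained, so a short usage sketch may help. This is not part of the diff; the token width, head count, and input shape below are illustrative assumptions for a ViT-B/16-style trunk.

# Usage sketch (illustrative, not part of the diff).
import torch
import torch.nn as nn

from timm.layers import AttentionPoolLatent  # exported per the import change above

pool = AttentionPoolLatent(
    in_features=768,           # token width of the trunk (assumed)
    num_heads=12,
    norm_layer=nn.LayerNorm,
    pool_type='token',         # return the single latent query token
)
tokens = torch.randn(2, 196, 768)  # (B, N, C), e.g. a 14x14 patch grid
pooled = pool(tokens)
print(pooled.shape)  # torch.Size([2, 768])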
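For the _hub.py change, the newly enabled branch means an OpenCLIP checkpoint name now advertises its safetensors twin to the loader. A quick behavior sketch (_get_safe_alternatives is a private helper; the constant values are assumed from timm/models/_hub.py):

# Behavior sketch: safetensors alternative for the OpenCLIP weight name.
from timm.models._hub import _get_safe_alternatives

print(list(_get_safe_alternatives('open_clip_pytorch_model.bin')))
# expected: ['open_clip_pytorch_model.safetensors']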
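Finally, once the hub repos named in the new cfgs are published, the SigLIP image towers should load like any other timm ViT. A hedged sketch (requires network access and assumes the listed hf_hub_id repos exist):

# Loading sketch: checkpoint_filter_fn strips the 'visual.trunk.' prefix
# from the OpenCLIP state dict when these weights are loaded.
import timm
import torch

model = timm.create_model('vit_base_patch16_siglip_224.webli', pretrained=True).eval()
feats = model(torch.randn(1, 3, 224, 224))  # num_classes=0, so pooled features are returned
print(feats.shape)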