mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Add SigLIP weights
parent 42daa3b497
commit 71365165a2
timm/layers/attention_pool.py (new file, +103 lines)
@@ -0,0 +1,103 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .config import use_fused_attn
+from .mlp import Mlp
+from .weight_init import trunc_normal_tf_
+
+
+class AttentionPoolLatent(nn.Module):
+    """ Attention pooling w/ latent query
+    """
+    fused_attn: torch.jit.Final[bool]
+
+    def __init__(
+            self,
+            in_features: int,
+            out_features: int = None,
+            embed_dim: int = None,
+            num_heads: int = 8,
+            mlp_ratio: float = 4.0,
+            qkv_bias: bool = True,
+            qk_norm: bool = False,
+            latent_len: int = 1,
+            latent_dim: int = None,
+            pos_embed: str = '',
+            pool_type: str = 'token',
+            norm_layer: Optional[nn.Module] = None,
+            drop: float = 0.0,
+    ):
+        super().__init__()
+        embed_dim = embed_dim or in_features
+        out_features = out_features or in_features
+        assert embed_dim % num_heads == 0
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        self.scale = self.head_dim ** -0.5
+        self.pool = pool_type
+        self.fused_attn = use_fused_attn()
+
+        if pos_embed == 'abs':
+            spatial_len = self.feat_size
+            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
+        else:
+            self.pos_embed = None
+
+        self.latent_dim = latent_dim or embed_dim
+        self.latent_len = latent_len
+        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))
+
+        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
+        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.proj = nn.Linear(embed_dim, embed_dim)
+        self.proj_drop = nn.Dropout(drop)
+
+        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
+        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
+
+        self.init_weights()
+
+    def init_weights(self):
+        if self.pos_embed is not None:
+            trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
+        trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5)
+
+    def forward(self, x):
+        B, N, C = x.shape
+
+        if self.pos_embed is not None:
+            # FIXME interpolate
+            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
+
+        q_latent = self.latent.expand(B, -1, -1)
+        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        k, v = kv.unbind(0)
+
+        q, k = self.q_norm(q), self.k_norm(k)
+
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        x = x + self.mlp(self.norm(x))
+
+        # optional pool if latent seq_len > 1 and pooled output is desired
+        if self.pool == 'token':
+            x = x[:, 0]
+        elif self.pool == 'avg':
+            x = x.mean(1)
+        return x
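Below is a brief usage sketch, not part of the commit, exercising the new AttentionPoolLatent layer on its own. The shapes are chosen for illustration (a ViT-B/16-style sequence of 196 tokens of width 768), and it assumes the layer is exported from timm.layers, as the vision_transformer.py import change later in this diff indicates.

import torch
import torch.nn as nn
from timm.layers import AttentionPoolLatent

# ViT-B/16 at 224x224 yields 196 patch tokens of width 768 (illustrative shapes).
pool = AttentionPoolLatent(in_features=768, num_heads=12, norm_layer=nn.LayerNorm)
tokens = torch.randn(2, 196, 768)   # (batch, seq_len, embed_dim)
pooled = pool(tokens)               # the learned latent query attends over all tokens
print(pooled.shape)                 # torch.Size([2, 768]) with the default pool_type='token'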
timm/models/_hub.py

@@ -376,7 +376,7 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
     """
     if filename == HF_WEIGHTS_NAME:
         yield HF_SAFE_WEIGHTS_NAME
-    # if filename == HF_OPEN_CLIP_WEIGHTS_NAME:  # FIXME tracking safetensors yet
-    #     yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
+    if filename == HF_OPEN_CLIP_WEIGHTS_NAME:
+        yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
     if filename not in (HF_WEIGHTS_NAME, HF_OPEN_CLIP_WEIGHTS_NAME) and filename.endswith(".bin"):
         yield filename[:-4] + ".safetensors"
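For reference, a self-contained sketch of the mapping this change enables. It is not part of the commit, and the constant filename values below are assumptions standing in for the real definitions in timm's hub module:

from typing import Iterable

# Assumed values for illustration; the real constants live in timm/models/_hub.py.
HF_WEIGHTS_NAME = "pytorch_model.bin"
HF_SAFE_WEIGHTS_NAME = "model.safetensors"
HF_OPEN_CLIP_WEIGHTS_NAME = "open_clip_pytorch_model.bin"
HF_OPEN_CLIP_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"

def _get_safe_alternatives(filename: str) -> Iterable[str]:
    # Mirrors the post-commit logic: OpenCLIP checkpoints now also get a safetensors candidate.
    if filename == HF_WEIGHTS_NAME:
        yield HF_SAFE_WEIGHTS_NAME
    if filename == HF_OPEN_CLIP_WEIGHTS_NAME:
        yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
    if filename not in (HF_WEIGHTS_NAME, HF_OPEN_CLIP_WEIGHTS_NAME) and filename.endswith(".bin"):
        yield filename[:-4] + ".safetensors"

print(list(_get_safe_alternatives("open_clip_pytorch_model.bin")))  # ['open_clip_model.safetensors']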
timm/models/vision_transformer.py

@@ -37,8 +37,8 @@ from torch.jit import Final
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
-    resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
+from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
+    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -377,95 +377,6 @@ class ParallelThingsBlock(nn.Module):
         return self._forward(x)
 
 
-class AttentionPoolLatent(nn.Module):
-    """ Attention pooling w/ latent query
-    """
-    def __init__(
-            self,
-            in_features: int,
-            out_features: int = None,
-            embed_dim: int = None,
-            num_heads: int = 8,
-            mlp_ratio: float = 4.0,
-            qkv_bias: bool = True,
-            qk_norm: bool = False,
-            latent_len: int = 1,
-            latent_dim: int = None,
-            pos_embed: str = '',
-            pool_type: str = 'token',
-            norm_layer: Optional[nn.Module] = None,
-            drop: float = 0.0,
-    ):
-        super().__init__()
-        embed_dim = embed_dim or in_features
-        out_features = out_features or in_features
-        assert embed_dim % num_heads == 0
-        self.num_heads = num_heads
-        self.head_dim = embed_dim // num_heads
-        self.scale = self.head_dim ** -0.5
-        self.pool = pool_type
-        self.fused_attn = use_fused_attn()
-
-        if pos_embed == 'abs':
-            spatial_len = self.feat_size
-            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
-        else:
-            self.pos_embed = None
-
-        self.latent_dim = latent_dim or embed_dim
-        self.latent_len = latent_len
-        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))
-
-        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
-        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
-        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.proj = nn.Linear(embed_dim, embed_dim)
-        self.proj_drop = nn.Dropout(drop)
-
-        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
-        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
-
-    def init_weights(self):
-        if self.pos_embed is not None:
-            trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
-
-    def forward(self, x):
-        B, N, C = x.shape
-
-        if self.pos_embed is not None:
-            # FIXME interpolate
-            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
-
-        q_latent = self.latent.expand(B, -1, -1)
-        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        k, v = kv.unbind(0)
-
-        q, k = self.q_norm(q), self.k_norm(k)
-
-        if self.fused_attn:
-            x = F.scaled_dot_product_attention(q, k, v)
-        else:
-            q = q * self.scale
-            attn = q @ k.transpose(-2, -1)
-            attn = attn.softmax(dim=-1)
-            x = attn @ v
-        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-
-        x = x + self.mlp(self.norm(x))
-
-        # optional pool if latent seq_len > 1 and pooled output is desired
-        if self.pool == 'token':
-            x = x[:, 0]
-        elif self.pool == 'avg':
-            x = x.mean(1)
-        return x
-
-
 class VisionTransformer(nn.Module):
     """ Vision Transformer
 
@@ -1072,6 +983,12 @@ def checkpoint_filter_fn(
     if "encoder" in state_dict:
         state_dict = _convert_ijepa(state_dict, model)
 
+    if 'visual.trunk.pos_embed' in state_dict:
+        # convert an OpenCLIP model with timm vision encoder
+        prefix = 'visual.trunk.'
+        state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
+        # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
+
     for k, v in state_dict.items():
         if 'patch_embed.proj.weight' in k:
             O, I, H, W = model.patch_embed.proj.weight.shape
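Not part of the commit: a toy illustration of the 'visual.trunk.' remap added above, with invented checkpoint keys and shapes. Only the vision-trunk tensors survive, renamed to plain timm keys, while text-tower keys are dropped.

import torch

# Invented OpenCLIP-style checkpoint keys for illustration only.
state_dict = {
    'visual.trunk.pos_embed': torch.zeros(1, 197, 768),
    'visual.trunk.patch_embed.proj.weight': torch.zeros(768, 3, 16, 16),
    'text.transformer.resblocks.0.attn.in_proj_weight': torch.zeros(2304, 768),
}

if 'visual.trunk.pos_embed' in state_dict:
    prefix = 'visual.trunk.'
    state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

print(sorted(state_dict))  # ['patch_embed.proj.weight', 'pos_embed']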
@@ -1634,48 +1551,42 @@ default_cfgs = generate_default_cfgs({
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
 
-    'vit_base_patch16_siglip_224': _cfg(
-        file='/data/n/temp/siglip/webli_en_b16_224_63724782.npz',
-        custom_load=True,
-        # hf_hub_id='timm/',
+    'vit_base_patch16_siglip_224.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=0),
-    'vit_base_patch16_siglip_256': _cfg(
-        file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
-        custom_load=True,
+    'vit_base_patch16_siglip_256.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 256, 256),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_base_patch16_siglip_384': _cfg(
-        file='',
-        custom_load=True,
+    'vit_base_patch16_siglip_384.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 384, 384),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_base_patch16_siglip_512': _cfg(
-        file='',
-        custom_load=True,
+    'vit_base_patch16_siglip_512.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-512',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 512, 512),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_large_patch16_siglip_256': _cfg(
-        custom_load=True,
+    'vit_large_patch16_siglip_256.webli': _cfg(
+        hf_hub_id='timm/ViT-L-16-SigLIP-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 256, 256),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_large_patch16_siglip_384': _cfg(
-        custom_load=True,
+    'vit_large_patch16_siglip_384.webli': _cfg(
+        hf_hub_id='timm/ViT-L-16-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 384, 384),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_so400m_patch14_siglip_224': _cfg(
-        # file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
-        custom_load=True,
-        # hf_hub_id='timm/',
+    'vit_so400m_patch14_siglip_224.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=0),
-    'vit_so400m_patch14_siglip_384': _cfg(
-        #file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
-        custom_load=True,
-        # hf_hub_id='timm/',
+    'vit_so400m_patch14_siglip_384.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 384, 384),
         num_classes=0),
 })
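A usage sketch, not part of the commit: with these pretrained tags in place, the SigLIP image towers should be creatable directly from the OpenCLIP checkpoints on the Hugging Face Hub (network access and the published hub weights are assumed; the expected output width of 768 follows from the ViT-B trunk).

import timm
import torch

# num_classes=0 returns pooled image features rather than classifier logits.
model = timm.create_model('vit_base_patch16_siglip_224.webli', pretrained=True, num_classes=0)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # expected: torch.Size([1, 768]) from the SigLIP ViT-B/16 image tower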