mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add SigLIP weights
This commit is contained in:
parent 42daa3b497
commit 71365165a2
103	timm/layers/attention_pool.py (new file)
@@ -0,0 +1,103 @@
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import use_fused_attn
from .mlp import Mlp
from .weight_init import trunc_normal_tf_


class AttentionPoolLatent(nn.Module):
    """ Attention pooling w/ latent query
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            in_features: int,
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 8,
            mlp_ratio: float = 4.0,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            latent_len: int = 1,
            latent_dim: int = None,
            pos_embed: str = '',
            pool_type: str = 'token',
            norm_layer: Optional[nn.Module] = None,
            drop: float = 0.0,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.pool = pool_type
        self.fused_attn = use_fused_attn()

        if pos_embed == 'abs':
            spatial_len = self.feat_size
            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
        else:
            self.pos_embed = None

        self.latent_dim = latent_dim or embed_dim
        self.latent_len = latent_len
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))

        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(drop)

        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))

        self.init_weights()

    def init_weights(self):
        if self.pos_embed is not None:
            trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
        trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5)

    def forward(self, x):
        B, N, C = x.shape

        if self.pos_embed is not None:
            # FIXME interpolate
            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        q_latent = self.latent.expand(B, -1, -1)
        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)

        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        x = x + self.mlp(self.norm(x))

        # optional pool if latent seq_len > 1 and pooled output is desired
        if self.pool == 'token':
            x = x[:, 0]
        elif self.pool == 'avg':
            x = x.mean(1)
        return x
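For orientation, here is a minimal usage sketch of the new module (illustrative only, not part of the commit). The shapes are assumptions for a ViT-B/16-like token sequence, and the default pos_embed='' is kept because the 'abs' branch reads a feat_size attribute that this standalone example does not set.

import torch
from timm.layers import AttentionPoolLatent  # imported from timm.layers elsewhere in this commit

# Hypothetical input: batch of 2 images, 196 patch tokens, 768-dim features.
tokens = torch.randn(2, 196, 768)

pool = AttentionPoolLatent(
    in_features=768,
    num_heads=12,
    latent_len=1,        # one learned latent query -> one pooled vector per image
    pool_type='token',   # return the latent token itself as the pooled output
)
pooled = pool(tokens)
print(pooled.shape)  # torch.Size([2, 768])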
@ -376,7 +376,7 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
|
|||||||
"""
|
"""
|
||||||
if filename == HF_WEIGHTS_NAME:
|
if filename == HF_WEIGHTS_NAME:
|
||||||
yield HF_SAFE_WEIGHTS_NAME
|
yield HF_SAFE_WEIGHTS_NAME
|
||||||
# if filename == HF_OPEN_CLIP_WEIGHTS_NAME: # FIXME tracking safetensors yet
|
if filename == HF_OPEN_CLIP_WEIGHTS_NAME:
|
||||||
# yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
|
yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
|
||||||
if filename not in (HF_WEIGHTS_NAME, HF_OPEN_CLIP_WEIGHTS_NAME) and filename.endswith(".bin"):
|
if filename not in (HF_WEIGHTS_NAME, HF_OPEN_CLIP_WEIGHTS_NAME) and filename.endswith(".bin"):
|
||||||
yield filename[:-4] + ".safetensors"
|
yield filename[:-4] + ".safetensors"
|
||||||
|
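To make the effect of this change concrete, here is a self-contained sketch of the resulting fallback logic; the constant values are assumptions chosen for illustration, not quoted from timm.

from typing import Iterable

# Assumed values, for demonstration only.
HF_WEIGHTS_NAME = "pytorch_model.bin"
HF_SAFE_WEIGHTS_NAME = "model.safetensors"
HF_OPEN_CLIP_WEIGHTS_NAME = "open_clip_pytorch_model.bin"
HF_OPEN_CLIP_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"

def _get_safe_alternatives(filename: str) -> Iterable[str]:
    # Mirror of the post-change logic: offer a safetensors alternative for known
    # weight filenames, including (newly) the OpenCLIP checkpoint name.
    if filename == HF_WEIGHTS_NAME:
        yield HF_SAFE_WEIGHTS_NAME
    if filename == HF_OPEN_CLIP_WEIGHTS_NAME:
        yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
    if filename not in (HF_WEIGHTS_NAME, HF_OPEN_CLIP_WEIGHTS_NAME) and filename.endswith(".bin"):
        yield filename[:-4] + ".safetensors"

print(list(_get_safe_alternatives("open_clip_pytorch_model.bin")))
# -> ['open_clip_model.safetensors'] under the assumed constants above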
@@ -37,8 +37,8 @@ from torch.jit import Final

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
-    resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
+from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
+    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -377,95 +377,6 @@ class ParallelThingsBlock(nn.Module):
         return self._forward(x)


-class AttentionPoolLatent(nn.Module):
-    """ Attention pooling w/ latent query
-    """
-    def __init__(
-            self,
-            in_features: int,
-            out_features: int = None,
-            embed_dim: int = None,
-            num_heads: int = 8,
-            mlp_ratio: float = 4.0,
-            qkv_bias: bool = True,
-            qk_norm: bool = False,
-            latent_len: int = 1,
-            latent_dim: int = None,
-            pos_embed: str = '',
-            pool_type: str = 'token',
-            norm_layer: Optional[nn.Module] = None,
-            drop: float = 0.0,
-    ):
-        super().__init__()
-        embed_dim = embed_dim or in_features
-        out_features = out_features or in_features
-        assert embed_dim % num_heads == 0
-        self.num_heads = num_heads
-        self.head_dim = embed_dim // num_heads
-        self.scale = self.head_dim ** -0.5
-        self.pool = pool_type
-        self.fused_attn = use_fused_attn()
-
-        if pos_embed == 'abs':
-            spatial_len = self.feat_size
-            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
-        else:
-            self.pos_embed = None
-
-        self.latent_dim = latent_dim or embed_dim
-        self.latent_len = latent_len
-        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))
-
-        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
-        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
-        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.proj = nn.Linear(embed_dim, embed_dim)
-        self.proj_drop = nn.Dropout(drop)
-
-        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
-        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
-
-    def init_weights(self):
-        if self.pos_embed is not None:
-            trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
-
-    def forward(self, x):
-        B, N, C = x.shape
-
-        if self.pos_embed is not None:
-            # FIXME interpolate
-            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
-
-        q_latent = self.latent.expand(B, -1, -1)
-        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        k, v = kv.unbind(0)
-
-        q, k = self.q_norm(q), self.k_norm(k)
-
-        if self.fused_attn:
-            x = F.scaled_dot_product_attention(q, k, v)
-        else:
-            q = q * self.scale
-            attn = q @ k.transpose(-2, -1)
-            attn = attn.softmax(dim=-1)
-            x = attn @ v
-        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-
-        x = x + self.mlp(self.norm(x))
-
-        # optional pool if latent seq_len > 1 and pooled output is desired
-        if self.pool == 'token':
-            x = x[:, 0]
-        elif self.pool == 'avg':
-            x = x.mean(1)
-        return x
-
-
 class VisionTransformer(nn.Module):
     """ Vision Transformer

@@ -1072,6 +983,12 @@ def checkpoint_filter_fn(
     if "encoder" in state_dict:
         state_dict = _convert_ijepa(state_dict, model)

+    if 'visual.trunk.pos_embed' in state_dict:
+        # convert an OpenCLIP model with timm vision encoder
+        prefix = 'visual.trunk.'
+        state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
+        # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
+
     for k, v in state_dict.items():
         if 'patch_embed.proj.weight' in k:
             O, I, H, W = model.patch_embed.proj.weight.shape
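The new branch simply strips the OpenCLIP wrapper prefix so the vision-tower weights line up with timm's key names. An illustrative run with made-up keys (placeholder strings stand in for tensors):

# Illustrative only: how the added branch remaps an OpenCLIP-style checkpoint whose
# timm vision tower lives under 'visual.trunk.'.
state_dict = {
    'visual.trunk.pos_embed': 'pos-embed-tensor',
    'visual.trunk.blocks.0.attn.qkv.weight': 'qkv-tensor',
    'visual.head.proj.weight': 'head-tensor',  # outside the trunk -> dropped (see FIXME above)
    'logit_scale': 'scale-tensor',
}
prefix = 'visual.trunk.'
state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
print(sorted(state_dict))
# -> ['blocks.0.attn.qkv.weight', 'pos_embed']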
@@ -1634,48 +1551,42 @@ default_cfgs = generate_default_cfgs({
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),

-    'vit_base_patch16_siglip_224': _cfg(
-        file='/data/n/temp/siglip/webli_en_b16_224_63724782.npz',
-        custom_load=True,
-        # hf_hub_id='timm/',
+    'vit_base_patch16_siglip_224.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=0),
-    'vit_base_patch16_siglip_256': _cfg(
-        file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
-        custom_load=True,
+    'vit_base_patch16_siglip_256.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 256, 256),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_base_patch16_siglip_384': _cfg(
-        file='',
-        custom_load=True,
+    'vit_base_patch16_siglip_384.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 384, 384),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_base_patch16_siglip_512': _cfg(
-        file='',
-        custom_load=True,
+    'vit_base_patch16_siglip_512.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-512',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 512, 512),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_large_patch16_siglip_256': _cfg(
-        custom_load=True,
+    'vit_large_patch16_siglip_256.webli': _cfg(
+        hf_hub_id='timm/ViT-L-16-SigLIP-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 256, 256),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_large_patch16_siglip_384': _cfg(
-        custom_load=True,
+    'vit_large_patch16_siglip_384.webli': _cfg(
+        hf_hub_id='timm/ViT-L-16-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 384, 384),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_so400m_patch14_siglip_224': _cfg(
-        # file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
-        custom_load=True,
-        # hf_hub_id='timm/',
+    'vit_so400m_patch14_siglip_224.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=0),
-    'vit_so400m_patch14_siglip_384': _cfg(
-        #file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
-        custom_load=True,
-        # hf_hub_id='timm/',
+    'vit_so400m_patch14_siglip_384.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 384, 384),
         num_classes=0),
 })
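With the configs now pointing at Hugging Face Hub repos instead of local .npz files, the new pretrained tags load like any other timm weight. A minimal sketch follows (requires network access to download the weights; the printed feature width is an assumption based on the ViT-B/16 embedding dim):

import timm
import torch

model = timm.create_model('vit_base_patch16_siglip_224.webli', pretrained=True, num_classes=0)
model.eval()

with torch.no_grad():
    features = model(torch.randn(1, 3, 224, 224))
print(features.shape)  # expected: torch.Size([1, 768]) for the ViT-B/16 image tower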