Add SigLIP weights

Ross Wightman 2023-10-16 23:26:08 -07:00
parent 42daa3b497
commit 71365165a2
3 changed files with 137 additions and 123 deletions

View File

@@ -0,0 +1,103 @@
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import use_fused_attn
from .mlp import Mlp
from .weight_init import trunc_normal_tf_


class AttentionPoolLatent(nn.Module):
    """ Attention pooling w/ latent query
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            in_features: int,
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 8,
            mlp_ratio: float = 4.0,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            latent_len: int = 1,
            latent_dim: int = None,
            pos_embed: str = '',
            pool_type: str = 'token',
            norm_layer: Optional[nn.Module] = None,
            drop: float = 0.0,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.pool = pool_type
        self.fused_attn = use_fused_attn()

        if pos_embed == 'abs':
            spatial_len = self.feat_size
            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
        else:
            self.pos_embed = None

        self.latent_dim = latent_dim or embed_dim
        self.latent_len = latent_len
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))

        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(drop)

        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))

        self.init_weights()

    def init_weights(self):
        if self.pos_embed is not None:
            trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
        trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5)

    def forward(self, x):
        B, N, C = x.shape

        if self.pos_embed is not None:
            # FIXME interpolate
            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        q_latent = self.latent.expand(B, -1, -1)
        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)

        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        x = x + self.mlp(self.norm(x))

        # optional pool if latent seq_len > 1 and pooled output is desired
        if self.pool == 'token':
            x = x[:, 0]
        elif self.pool == 'avg':
            x = x.mean(1)
        return x
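
Note: for orientation, a minimal usage sketch of the new module; the embedding width, head count, and input sizes below are illustrative assumptions, not values taken from this commit. AttentionPoolLatent is importable from timm.layers per the vision_transformer.py import change further down.

import torch
import torch.nn as nn
from timm.layers import AttentionPoolLatent  # exported from timm.layers after this commit

# Pool a (B, N, C) ViT token sequence down to a single feature vector per image.
pool = AttentionPoolLatent(
    in_features=768,            # assumed ViT-B/16 width, for illustration only
    num_heads=12,
    norm_layer=nn.LayerNorm,
)
tokens = torch.randn(2, 196, 768)   # (batch, patches, channels)
pooled = pool(tokens)               # default latent_len=1, pool_type='token' -> (2, 768)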

View File

@@ -376,7 +376,7 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
"""
if filename == HF_WEIGHTS_NAME:
yield HF_SAFE_WEIGHTS_NAME
# if filename == HF_OPEN_CLIP_WEIGHTS_NAME: # FIXME tracking safetensors yet
# yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
if filename == HF_OPEN_CLIP_WEIGHTS_NAME:
yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
if filename not in (HF_WEIGHTS_NAME, HF_OPEN_CLIP_WEIGHTS_NAME) and filename.endswith(".bin"):
yield filename[:-4] + ".safetensors"
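
Note: the net effect of the change above is that OpenCLIP-format checkpoints now also get a safetensors alternative probed. A standalone sketch of the resulting lookup behaviour follows; the constant values are assumptions based on common Hugging Face / OpenCLIP file naming and are not shown in this diff.

from typing import Iterable

# Assumed values for the module-level constants referenced above:
HF_WEIGHTS_NAME = "pytorch_model.bin"
HF_SAFE_WEIGHTS_NAME = "model.safetensors"
HF_OPEN_CLIP_WEIGHTS_NAME = "open_clip_pytorch_model.bin"
HF_OPEN_CLIP_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"

def safe_alternatives(filename: str) -> Iterable[str]:
    # mirrors the post-change logic of _get_safe_alternatives
    if filename == HF_WEIGHTS_NAME:
        yield HF_SAFE_WEIGHTS_NAME
    if filename == HF_OPEN_CLIP_WEIGHTS_NAME:
        yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
    if filename not in (HF_WEIGHTS_NAME, HF_OPEN_CLIP_WEIGHTS_NAME) and filename.endswith(".bin"):
        yield filename[:-4] + ".safetensors"

print(list(safe_alternatives("open_clip_pytorch_model.bin")))  # ['open_clip_model.safetensors']
print(list(safe_alternatives("custom_weights.bin")))           # ['custom_weights.safetensors']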

View File

@@ -37,8 +37,8 @@ from torch.jit import Final
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
    OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
    resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -377,95 +377,6 @@ class ParallelThingsBlock(nn.Module):
        return self._forward(x)


class AttentionPoolLatent(nn.Module):
    """ Attention pooling w/ latent query
    """

    def __init__(
            self,
            in_features: int,
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 8,
            mlp_ratio: float = 4.0,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            latent_len: int = 1,
            latent_dim: int = None,
            pos_embed: str = '',
            pool_type: str = 'token',
            norm_layer: Optional[nn.Module] = None,
            drop: float = 0.0,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.pool = pool_type
        self.fused_attn = use_fused_attn()

        if pos_embed == 'abs':
            spatial_len = self.feat_size
            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
        else:
            self.pos_embed = None

        self.latent_dim = latent_dim or embed_dim
        self.latent_len = latent_len
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))

        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(drop)

        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))

    def init_weights(self):
        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)

    def forward(self, x):
        B, N, C = x.shape

        if self.pos_embed is not None:
            # FIXME interpolate
            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        q_latent = self.latent.expand(B, -1, -1)
        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)

        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        x = x + self.mlp(self.norm(x))

        # optional pool if latent seq_len > 1 and pooled output is desired
        if self.pool == 'token':
            x = x[:, 0]
        elif self.pool == 'avg':
            x = x.mean(1)
        return x


class VisionTransformer(nn.Module):
    """ Vision Transformer
@@ -1072,6 +983,12 @@ def checkpoint_filter_fn(
if "encoder" in state_dict:
state_dict = _convert_ijepa(state_dict, model)
if 'visual.trunk.pos_embed' in state_dict:
# convert an OpenCLIP model with timm vision encoder
prefix = 'visual.trunk.'
state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
# FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k:
O, I, H, W = model.patch_embed.proj.weight.shape
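
Note: a toy illustration of the new 'visual.trunk.' branch above, with invented checkpoint keys: keys under the prefix are kept with the prefix stripped, everything else (text tower, visual.head projection) is dropped for now.

import torch

# Hypothetical OpenCLIP-style state dict (keys and shapes invented for illustration):
sd = {
    'visual.trunk.pos_embed': torch.zeros(1, 197, 768),
    'visual.trunk.blocks.0.attn.qkv.weight': torch.zeros(2304, 768),
    'visual.head.proj.weight': torch.zeros(512, 768),            # currently dropped (see FIXME above)
    'text.transformer.token_embedding.weight': torch.zeros(49408, 512),
}
prefix = 'visual.trunk.'
sd = {k[len(prefix):]: v for k, v in sd.items() if k.startswith(prefix)}
print(list(sd))  # ['pos_embed', 'blocks.0.attn.qkv.weight']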
@@ -1634,48 +1551,42 @@ default_cfgs = generate_default_cfgs({
        license='cc-by-nc-4.0',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
    'vit_base_patch16_siglip_224': _cfg(
        file='/data/n/temp/siglip/webli_en_b16_224_63724782.npz',
        custom_load=True,
        # hf_hub_id='timm/',
    'vit_base_patch16_siglip_224.webli': _cfg(
        hf_hub_id='timm/ViT-B-16-SigLIP',
        hf_hub_filename='open_clip_pytorch_model.bin',
        num_classes=0),
    'vit_base_patch16_siglip_256': _cfg(
        file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
        custom_load=True,
    'vit_base_patch16_siglip_256.webli': _cfg(
        hf_hub_id='timm/ViT-B-16-SigLIP-256',
        hf_hub_filename='open_clip_pytorch_model.bin',
        input_size=(3, 256, 256),
        # hf_hub_id='timm/',
        num_classes=0),
    'vit_base_patch16_siglip_384': _cfg(
        file='',
        custom_load=True,
    'vit_base_patch16_siglip_384.webli': _cfg(
        hf_hub_id='timm/ViT-B-16-SigLIP-384',
        hf_hub_filename='open_clip_pytorch_model.bin',
        input_size=(3, 384, 384),
        # hf_hub_id='timm/',
        num_classes=0),
    'vit_base_patch16_siglip_512': _cfg(
        file='',
        custom_load=True,
    'vit_base_patch16_siglip_512.webli': _cfg(
        hf_hub_id='timm/ViT-B-16-SigLIP-512',
        hf_hub_filename='open_clip_pytorch_model.bin',
        input_size=(3, 512, 512),
        # hf_hub_id='timm/',
        num_classes=0),
    'vit_large_patch16_siglip_256': _cfg(
        custom_load=True,
    'vit_large_patch16_siglip_256.webli': _cfg(
        hf_hub_id='timm/ViT-L-16-SigLIP-256',
        hf_hub_filename='open_clip_pytorch_model.bin',
        input_size=(3, 256, 256),
        # hf_hub_id='timm/',
        num_classes=0),
    'vit_large_patch16_siglip_384': _cfg(
        custom_load=True,
    'vit_large_patch16_siglip_384.webli': _cfg(
        hf_hub_id='timm/ViT-L-16-SigLIP-384',
        hf_hub_filename='open_clip_pytorch_model.bin',
        input_size=(3, 384, 384),
        # hf_hub_id='timm/',
        num_classes=0),
    'vit_so400m_patch14_siglip_224': _cfg(
        # file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
        custom_load=True,
        # hf_hub_id='timm/',
    'vit_so400m_patch14_siglip_224.webli': _cfg(
        hf_hub_id='timm/ViT-SO400M-14-SigLIP',
        hf_hub_filename='open_clip_pytorch_model.bin',
        num_classes=0),
    'vit_so400m_patch14_siglip_384': _cfg(
        #file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
        custom_load=True,
        # hf_hub_id='timm/',
    'vit_so400m_patch14_siglip_384.webli': _cfg(
        hf_hub_id='timm/ViT-SO400M-14-SigLIP-384',
        hf_hub_filename='open_clip_pytorch_model.bin',
        input_size=(3, 384, 384),
        num_classes=0),
})
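
Note: with the default_cfgs above pointing at the OpenCLIP weight files on the Hub, the SigLIP image towers can be created directly through timm. A usage sketch (requires network access and huggingface_hub; pretrained tag names exactly as listed above):

import timm
import torch

# Image tower only; num_classes=0 in the cfg means the model returns pooled features.
model = timm.create_model('vit_base_patch16_siglip_224.webli', pretrained=True)
model.eval()

with torch.no_grad():
    features = model(torch.randn(1, 3, 224, 224))
print(features.shape)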