Merge pull request #1973 from huggingface/vit_siglip_and_reg

Working on support for SigLIP (w/ attn pool) ViT backbone and registers

Commit: e1e7cf5275
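If the registrations below land as shown, the new backbones should be constructible through the usual timm factory. A minimal sketch (pretrained=False so no weight download or hub availability is assumed):

import timm
import torch

# SigLIP image tower: no class token, attention-pooled ('map') head
model = timm.create_model('vit_base_patch16_siglip_224', pretrained=False, num_classes=0)
feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 768])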
timm/layers/__init__.py
@@ -1,6 +1,7 @@
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
+from .attention_pool import AttentionPoolLatent
 from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
 from .blur_pool import BlurPool2d
 from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead
timm/layers/attention_pool.py (new file, 103 lines)
@@ -0,0 +1,103 @@
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import use_fused_attn
from .mlp import Mlp
from .weight_init import trunc_normal_tf_


class AttentionPoolLatent(nn.Module):
    """ Attention pooling w/ latent query
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            in_features: int,
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 8,
            mlp_ratio: float = 4.0,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            latent_len: int = 1,
            latent_dim: int = None,
            pos_embed: str = '',
            pool_type: str = 'token',
            norm_layer: Optional[nn.Module] = None,
            drop: float = 0.0,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.pool = pool_type
        self.fused_attn = use_fused_attn()

        if pos_embed == 'abs':
            spatial_len = self.feat_size
            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
        else:
            self.pos_embed = None

        self.latent_dim = latent_dim or embed_dim
        self.latent_len = latent_len
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))

        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(drop)

        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))

        self.init_weights()

    def init_weights(self):
        if self.pos_embed is not None:
            trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
        trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5)

    def forward(self, x):
        B, N, C = x.shape

        if self.pos_embed is not None:
            # FIXME interpolate
            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        q_latent = self.latent.expand(B, -1, -1)
        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)

        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        x = x + self.mlp(self.norm(x))

        # optional pool if latent seq_len > 1 and pooled output is desired
        if self.pool == 'token':
            x = x[:, 0]
        elif self.pool == 'avg':
            x = x.mean(1)
        return x
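A minimal usage sketch of the module above (shapes are illustrative): a learned latent query attends over the (B, N, C) token sequence and, with the defaults pool_type='token' and latent_len=1, returns a (B, C) pooled embedding. Note the pos_embed='abs' branch references self.feat_size, which is not set anywhere in this version of the file, so the sketch sticks to the default pos_embed=''.

import torch
from timm.layers import AttentionPoolLatent

pool = AttentionPoolLatent(in_features=768, num_heads=12)  # embed_dim defaults to in_features
tokens = torch.randn(2, 196, 768)                          # (B, N, C) patch tokens from a ViT
pooled = pool(tokens)
print(pooled.shape)                                        # torch.Size([2, 768])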
timm/models/_builder.py
@@ -160,7 +160,11 @@ def load_pretrained(
         state_dict = pretrained_loc  # pretrained_loc is the actual state dict for this override
     elif load_from == 'file':
         _logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
-        state_dict = load_state_dict(pretrained_loc)
+        if pretrained_cfg.get('custom_load', False):
+            model.load_pretrained(pretrained_loc)
+            return
+        else:
+            state_dict = load_state_dict(pretrained_loc)
     elif load_from == 'url':
         _logger.info(f'Loading pretrained weights from url ({pretrained_loc})')
         if pretrained_cfg.get('custom_load', False):
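The new branch lets a file-based pretrained_cfg defer to the model's own loader (as used for the .npz JAX/big-vision style checkpoints) instead of torch-format loading. A hedged sketch of the intended call path; the checkpoint path is hypothetical and the overlay keys are assumptions based on timm's pretrained_cfg fields:

import timm

# Hypothetical: point at a local .npz checkpoint; with custom_load=True the file
# branch now calls model.load_pretrained(path) rather than load_state_dict(path).
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    pretrained_cfg_overlay=dict(file='/path/to/checkpoint.npz', custom_load=True),
)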
timm/models/_hub.py
@@ -376,7 +376,7 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
     """
     if filename == HF_WEIGHTS_NAME:
         yield HF_SAFE_WEIGHTS_NAME
-    # if filename == HF_OPEN_CLIP_WEIGHTS_NAME:  # FIXME tracking safetensors yet
-    #     yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
+    if filename == HF_OPEN_CLIP_WEIGHTS_NAME:
+        yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
     if filename not in (HF_WEIGHTS_NAME, HF_OPEN_CLIP_WEIGHTS_NAME) and filename.endswith(".bin"):
         yield filename[:-4] + ".safetensors"
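A standalone sketch of the filename mapping this enables. The constant values are assumptions based on timm's hub naming conventions and are not shown in this hunk:

from typing import Iterable

# Assumed values of the module-level constants referenced above.
HF_WEIGHTS_NAME = "pytorch_model.bin"
HF_SAFE_WEIGHTS_NAME = "model.safetensors"
HF_OPEN_CLIP_WEIGHTS_NAME = "open_clip_pytorch_model.bin"
HF_OPEN_CLIP_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"

def safe_alternatives(filename: str) -> Iterable[str]:
    # Mirrors _get_safe_alternatives(): prefer a .safetensors file when one may exist.
    if filename == HF_WEIGHTS_NAME:
        yield HF_SAFE_WEIGHTS_NAME
    if filename == HF_OPEN_CLIP_WEIGHTS_NAME:
        yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
    if filename not in (HF_WEIGHTS_NAME, HF_OPEN_CLIP_WEIGHTS_NAME) and filename.endswith(".bin"):
        yield filename[:-4] + ".safetensors"

print(list(safe_alternatives("open_clip_pytorch_model.bin")))  # ['open_clip_model.safetensors']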
timm/models/vision_transformer.py
@@ -37,8 +37,8 @@ from torch.jit import Final

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
-    resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
+from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
+    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -401,6 +401,7 @@ class VisionTransformer(nn.Module):
             init_values: Optional[float] = None,
             class_token: bool = True,
             no_embed_class: bool = False,
+            reg_tokens: int = 0,
             pre_norm: bool = False,
             fc_norm: Optional[bool] = None,
             dynamic_img_size: bool = False,
@@ -432,6 +433,8 @@ class VisionTransformer(nn.Module):
             qkv_bias: Enable bias for qkv projections if True.
             init_values: Layer-scale init values (layer-scale enabled if not None).
             class_token: Use class token.
+            no_embed_class: Don't include position embeddings for class (or reg) tokens.
+            reg_tokens: Number of register tokens.
             fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
             drop_rate: Head dropout rate.
             pos_drop_rate: Position embedding dropout rate.
@@ -444,7 +447,7 @@ class VisionTransformer(nn.Module):
             block_fn: Transformer block layer.
         """
         super().__init__()
-        assert global_pool in ('', 'avg', 'token')
+        assert global_pool in ('', 'avg', 'token', 'map')
         assert class_token or global_pool != 'token'
         use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
@@ -454,7 +457,10 @@ class VisionTransformer(nn.Module):
         self.global_pool = global_pool
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
-        self.no_embed_class = no_embed_class
+        self.num_prefix_tokens += reg_tokens
+        self.num_reg_tokens = reg_tokens
+        self.has_class_token = class_token
+        self.no_embed_class = no_embed_class  # don't embed prefix positions (includes reg)
         self.dynamic_img_size = dynamic_img_size
         self.grad_checkpointing = False

@@ -474,6 +480,7 @@ class VisionTransformer(nn.Module):
         num_patches = self.patch_embed.num_patches

         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
+        self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
         embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
         self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
         self.pos_drop = nn.Dropout(p=pos_drop_rate)
@@ -506,6 +513,15 @@ class VisionTransformer(nn.Module):
         self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

         # Classifier Head
+        if global_pool == 'map':
+            self.attn_pool = AttentionPoolLatent(
+                self.embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                norm_layer=norm_layer,
+            )
+        else:
+            self.attn_pool = None
         self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
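To illustrate what the new 'map' mode wires up, a small sketch constructing the class directly with the arguments shown in this hunk (dimensions chosen arbitrarily, untrained weights):

import torch
from timm.models.vision_transformer import VisionTransformer

model = VisionTransformer(
    img_size=224, patch_size=16, embed_dim=192, depth=2, num_heads=3,
    class_token=False, global_pool='map', num_classes=10)
logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 10]) -- pooled by AttentionPoolLatent, not a class token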
@@ -566,18 +582,26 @@ class VisionTransformer(nn.Module):
             x = x.view(B, -1, C)
         else:
             pos_embed = self.pos_embed

+        to_cat = []
+        if self.cls_token is not None:
+            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
+        if self.reg_token is not None:
+            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
+
         if self.no_embed_class:
             # deit-3, updated JAX (big vision)
             # position embedding does not overlap with class token, add then concat
             x = x + pos_embed
-            if self.cls_token is not None:
-                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+            if to_cat:
+                x = torch.cat(to_cat + [x], dim=1)
         else:
             # original timm, JAX, and deit vit impl
             # pos_embed has entry for class token, concat then add
-            if self.cls_token is not None:
-                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+            if to_cat:
+                x = torch.cat(to_cat + [x], dim=1)
             x = x + pos_embed

         return self.pos_drop(x)

     def _intermediate_layers(
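The resulting token layout, illustrated standalone with toy tensors (the ordering, class token first, then register tokens, then patch tokens, follows the to_cat list above):

import torch

B, N, C = 2, 4, 8
cls_token = torch.zeros(1, 1, C)
reg_token = torch.zeros(1, 3, C)   # e.g. 3 register tokens
patches = torch.randn(B, N, C)

to_cat = [cls_token.expand(B, -1, -1), reg_token.expand(B, -1, -1)]
x = torch.cat(to_cat + [patches], dim=1)
print(x.shape)  # torch.Size([2, 8, 8]) -> 1 cls + 3 reg + 4 patch tokens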
@@ -605,7 +629,7 @@ class VisionTransformer(nn.Module):
             x: torch.Tensor,
             n: Union[int, Sequence] = 1,
             reshape: bool = False,
-            return_class_token: bool = False,
+            return_prefix_tokens: bool = False,
             norm: bool = False,
     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
         """ Intermediate layer accessor (NOTE: This is a WIP experiment).
@@ -615,7 +639,7 @@ class VisionTransformer(nn.Module):
         outputs = self._intermediate_layers(x, n)
         if norm:
             outputs = [self.norm(out) for out in outputs]
-        class_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
+        prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
         outputs = [out[:, self.num_prefix_tokens:] for out in outputs]

         if reshape:
@@ -625,8 +649,8 @@ class VisionTransformer(nn.Module):
                 for out in outputs
             ]

-        if return_class_token:
-            return tuple(zip(outputs, class_tokens))
+        if return_prefix_tokens:
+            return tuple(zip(outputs, prefix_tokens))
         return tuple(outputs)

     def forward_features(self, x):
@@ -642,8 +666,12 @@ class VisionTransformer(nn.Module):
         return x

     def forward_head(self, x, pre_logits: bool = False):
-        if self.global_pool:
-            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
+        if self.attn_pool is not None:
+            x = self.attn_pool(x)
+        elif self.global_pool == 'avg':
+            x = x[:, self.num_prefix_tokens:].mean(dim=1)
+        elif self.global_pool:
+            x = x[:, 0]  # class token
         x = self.fc_norm(x)
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
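A sketch of the head dispatch with one of the register-token models registered further down (shapes only, untrained weights): 'map' would route through attn_pool, 'avg' averages patch tokens only, and any other non-empty pool takes the class token.

import timm
import torch

model = timm.create_model('vit_base_patch16_reg8_gap_256', pretrained=False)
tokens = model.forward_features(torch.randn(1, 3, 256, 256))   # (1, 8 + 256, 768): 8 reg + 256 patch tokens
pooled = model.forward_head(tokens, pre_logits=True)           # 'avg' pools patch tokens only -> (1, 768)
print(tokens.shape, pooled.shape)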
@@ -767,6 +795,9 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     elif 'params/embedding/kernel' in w:
         prefix = 'params/'
         big_vision = True
+    elif 'params/img/embedding/kernel' in w:
+        prefix = 'params/img/'
+        big_vision = True

     if hasattr(model.patch_embed, 'backbone'):
         # hybrid
@@ -823,13 +854,33 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     model.pos_embed.copy_(pos_embed_w)
     model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
     model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
-    if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+    if (isinstance(model.head, nn.Linear) and
+            f'{prefix}head/bias' in w and
+            model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]):
         model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
         model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
     # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights
     # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
     #     model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
     #     model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
+    if model.attn_pool is not None:
+        block_prefix = f'{prefix}MAPHead_0/'
+        mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
+        model.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
+        model.attn_pool.kv.weight.copy_(torch.cat([
+            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
+        model.attn_pool.kv.bias.copy_(torch.cat([
+            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
+        model.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
+        model.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
+        model.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
+        model.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
+        model.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
+        model.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+        for r in range(2):
+            getattr(model.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
+            getattr(model.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))

     mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
     for i, block in enumerate(model.blocks.children()):
         block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
@@ -842,11 +893,11 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
             _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
         block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
         block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
+        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
+        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
         for r in range(2):
             getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
             getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
-        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
-        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))


 def _convert_openai_clip(state_dict, model):
@@ -899,17 +950,6 @@ def _convert_dinov2(state_dict, model):
     return out_dict


-def _convert_ijepa(state_dict, model):
-    out_dict = {}
-    for k, v in state_dict['encoder'].items():
-        if k.startswith('module.'):
-            k = k[7:]
-        if k.startswith('norm.'):
-            k = 'fc_norm.' + k[5:]
-        out_dict[k] = v
-    return out_dict
-
-
 def checkpoint_filter_fn(
         state_dict,
         model,
@@ -922,6 +962,7 @@ def checkpoint_filter_fn(
     out_dict = {}
     state_dict = state_dict.get('model', state_dict)
     state_dict = state_dict.get('state_dict', state_dict)
+    prefix = ''

     if 'visual.class_embedding' in state_dict:
         return _convert_openai_clip(state_dict, model)
@@ -930,7 +971,17 @@ def checkpoint_filter_fn(
         state_dict = _convert_dinov2(state_dict, model)

     if "encoder" in state_dict:
-        state_dict = _convert_ijepa(state_dict, model)
+        state_dict = state_dict['encoder']
+        prefix = 'module.'
+
+    if 'visual.trunk.pos_embed' in state_dict:
+        # convert an OpenCLIP model with timm vision encoder
+        # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
+        prefix = 'visual.trunk.'
+
+    if prefix:
+        # filter on & remove prefix string from keys
+        state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

     for k, v in state_dict.items():
         if 'patch_embed.proj.weight' in k:
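The prefix handling above in isolation, with a toy state dict ('visual.trunk.' is the OpenCLIP-wrapped timm tower case; values are placeholders):

prefix = 'visual.trunk.'
state_dict = {
    'visual.trunk.pos_embed': 'tensor...',
    'visual.trunk.blocks.0.attn.qkv.weight': 'tensor...',
    'visual.head.proj.weight': 'dropped (outside the trunk)',
}
filtered = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
print(sorted(filtered))  # ['blocks.0.attn.qkv.weight', 'pos_embed']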
@@ -1472,27 +1523,72 @@ default_cfgs = generate_default_cfgs({
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),

-    'vit_huge_patch14_224_ijepa.in1k': _cfg(
+    'vit_huge_patch14_gap_224.in1k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_huge_patch14_224_ijepa.in22k': _cfg(
+    'vit_huge_patch14_gap_224.in22k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_huge_patch16_448_ijepa.in1k': _cfg(
+    'vit_huge_patch16_gap_448.in1k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         input_size=(3, 448, 448), crop_pct=1.0,
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_gigantic_patch16_224_ijepa.in22k': _cfg(
+    'vit_giant_patch16_gap_224.in22k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+
+    'vit_base_patch16_siglip_224.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0),
+    'vit_base_patch16_siglip_256.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 256, 256),
+        num_classes=0),
+    'vit_base_patch16_siglip_384.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 384, 384),
+        num_classes=0),
+    'vit_base_patch16_siglip_512.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-512',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 512, 512),
+        num_classes=0),
+    'vit_large_patch16_siglip_256.webli': _cfg(
+        hf_hub_id='timm/ViT-L-16-SigLIP-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 256, 256),
+        num_classes=0),
+    'vit_large_patch16_siglip_384.webli': _cfg(
+        hf_hub_id='timm/ViT-L-16-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 384, 384),
+        num_classes=0),
+    'vit_so400m_patch14_siglip_224.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_384.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 384, 384),
+        num_classes=0),
+
+    'vit_medium_patch16_reg4_256': _cfg(
+        input_size=(3, 256, 256)),
+    'vit_medium_patch16_reg4_gap_256': _cfg(
+        input_size=(3, 256, 256)),
+    'vit_base_patch16_reg8_gap_256': _cfg(input_size=(3, 256, 256)),
 })

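Once these default_cfgs are registered, the new entries should be discoverable in the usual way (a sketch; output abbreviated):

import timm

print(timm.list_models('*siglip*'))
# e.g. ['vit_base_patch16_siglip_224', ..., 'vit_so400m_patch14_siglip_384']
print(timm.list_models('*reg4*'))
# e.g. ['vit_medium_patch16_reg4_256', 'vit_medium_patch16_reg4_gap_256']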
@@ -1754,7 +1850,7 @@ def vit_medium_patch16_gap_384(pretrained=False, **kwargs) -> VisionTransformer:

 @register_model
 def vit_base_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256
+    """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224
     """
     model_args = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
@@ -1763,6 +1859,40 @@ def vit_base_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
     return model


+@register_model
+def vit_huge_patch14_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/14) w/ no class token, avg pool
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_huge_patch14_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_huge_patch16_gap_448(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/16) w/ no class token, avg pool @ 448x448
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_huge_patch16_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_giant_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Giant (little-gg) model (ViT-g/16) w/ no class token, avg pool
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
+        class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_giant_patch16_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_base_patch32_clip_224(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-B/32 CLIP image tower @ 224x224
@@ -2089,33 +2219,115 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:


 @register_model
-def vit_huge_patch14_224_ijepa(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Huge model (ViT-H/14) from `I-JEPA` - https://arxiv.org/abs/2301.08243
-    """
-    model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg')
-    model = _create_vision_transformer(
-        'vit_huge_patch14_224_ijepa', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
-
-@register_model
-def vit_huge_patch16_448_ijepa(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Huge model (ViT-H/16) from `I-JEPA` - https://arxiv.org/abs/2301.08243
-    """
+def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
     model_args = dict(
-        patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', img_size=448)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
+    )
     model = _create_vision_transformer(
-        'vit_huge_patch16_448_ijepa', pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_base_patch16_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
-def vit_gigantic_patch16_224_ijepa(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Gigantic (big-G) model (ViT-G/16) from `I-JEPA - https://arxiv.org/abs/2301.08243
-    """
-    model_args = dict(patch_size=16, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16)
+def vit_base_patch16_siglip_256(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
+    )
     model = _create_vision_transformer(
-        'vit_gigantic_patch16_224_ijepa', pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_base_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


+@register_model
+def vit_base_patch16_siglip_384(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_siglip_512(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_siglip_512', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch16_siglip_256(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch16_siglip_384(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_384(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_medium_patch16_reg4_256(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True,
+        no_embed_class=True, reg_tokens=4,
+    )
+    model = _create_vision_transformer(
+        'vit_medium_patch16_reg4_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_medium_patch16_reg4_gap_256(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=512, depth=12, num_heads=8,
+        class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
+    )
+    model = _create_vision_transformer(
+        'vit_medium_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_reg8_gap_256(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False,
+        no_embed_class=True, global_pool='avg', reg_tokens=8,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_reg8_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model