Add more pali / pali2 weights. Switch the rest of the models adapting open_clip weights to their own weight instances.

Ross Wightman 2024-12-27 12:05:22 -08:00
parent 4f4f40baa6
commit 5cf022f228
4 changed files with 178 additions and 107 deletions
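For reference, a minimal usage sketch of what the switch enables (assuming the re-hosted 'timm/' weight instances referenced by the new cfgs are published on the Hub):

import timm

# One of the pali2 image-tower tags added in this commit; with hf_hub_id='timm/'
# the weights resolve to a timm-hosted instance rather than an adapted
# open_clip / original-repo checkpoint.
model = timm.create_model(
    'vit_so400m_patch14_siglip_gap_448.pali2_3b_pt',
    pretrained=True,
    num_classes=0,  # these cfgs are feature towers, no classifier head
)
model.eval()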

View File

@@ -2282,56 +2282,48 @@ default_cfgs = generate_default_cfgs({
# original attention pool head variants
'resnet50_clip.openai': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=1024, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7),
classifier='head.proj',
),
'resnet101_clip.openai': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=512, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7),
classifier='head.proj',
),
'resnet50x4_clip.openai': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=640, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 288, 288), pool_size=(9, 9),
classifier='head.proj',
),
'resnet50x16_clip.openai': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=768, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 384, 384), pool_size=(12, 12),
classifier='head.proj',
),
'resnet50x64_clip.openai': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=1024, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 448, 448), pool_size=(14, 14),
classifier='head.proj',
),
'resnet50_clip.cc12m': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=1024, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7),
classifier='head.proj',
),
'resnet50_clip.yfcc15m': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=1024, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7),
classifier='head.proj',
),
'resnet101_clip.yfcc15m': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=512, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7),
classifier='head.proj',
@@ -2339,50 +2331,42 @@ default_cfgs = generate_default_cfgs({
# avg-pool w/ optional standard classifier head variants
'resnet50_clip_gap.openai': _cfgr(
hf_hub_id='timm/resnet50_clip.openai',
hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 224, 224), pool_size=(7, 7),
),
'resnet101_clip_gap.openai': _cfgr(
hf_hub_id='timm/resnet101_clip.openai',
hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 224, 224), pool_size=(7, 7),
),
'resnet50x4_clip_gap.openai': _cfgr(
hf_hub_id='timm/resnet50x4_clip.openai',
hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 288, 288), pool_size=(9, 9),
),
'resnet50x16_clip_gap.openai': _cfgr(
hf_hub_id='timm/resnet50x16_clip.openai',
hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 384, 384), pool_size=(12, 12),
),
'resnet50x64_clip_gap.openai': _cfgr(
hf_hub_id='timm/resnet50x64_clip.openai',
hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 448, 448), pool_size=(14, 14),
),
'resnet50_clip_gap.cc12m': _cfgr(
hf_hub_id='timm/resnet50_clip.cc12m',
hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 224, 224), pool_size=(7, 7),
),
'resnet50_clip_gap.yfcc15m': _cfgr(
hf_hub_id='timm/resnet50_clip.yfcc15m',
hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 224, 224), pool_size=(7, 7),
),
'resnet101_clip_gap.yfcc15m': _cfgr(
hf_hub_id='timm/resnet101_clip.yfcc15m',
hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 224, 224), pool_size=(7, 7),
),
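The gap variants above now point at plain 'timm/' hub ids like the attention-pool variants, instead of a sibling repo plus open_clip_pytorch_model.bin. A minimal sketch of the naming convention this relies on, assuming an org-only id of 'timm/' expands by appending the full model name (architecture.tag); resolve_repo_id is a hypothetical helper, not timm's internal function:

def resolve_repo_id(hf_hub_id: str, model_name: str) -> str:
    # Hypothetical helper: 'timm/' + 'resnet50_clip_gap.openai'
    # -> 'timm/resnet50_clip_gap.openai' on the Hugging Face Hub.
    if hf_hub_id.endswith('/'):
        return hf_hub_id + model_name
    return hf_hub_id

assert resolve_repo_id('timm/', 'resnet50_clip_gap.openai') == 'timm/resnet50_clip_gap.openai'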

View File

@@ -912,45 +912,52 @@ default_cfgs = generate_default_cfgs({
# EVA01 and EVA02 CLIP image towers
'eva_giant_patch14_clip_224.laion400m': _cfg(
# hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA01_CLIP_g_14_plus_psz14_s11B.pt',
hf_hub_id='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k', # float16 weights
hf_hub_filename='open_clip_pytorch_model.bin',
# hf_hub_id='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k', # float16 weights
# hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=1024,
),
'eva_giant_patch14_clip_224.merged2b': _cfg(
# hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA01_CLIP_g_14_plus_psz14_s11B.pt',
hf_hub_id='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k', # float16 weights
hf_hub_filename='open_clip_pytorch_model.bin',
# hf_hub_id='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k', # float16 weights
# hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=1024,
),
'eva02_base_patch16_clip_224.merged2b': _cfg(
# hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
hf_hub_id='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k', # float16 weights
hf_hub_filename='open_clip_pytorch_model.bin',
# hf_hub_id='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k', # float16 weights
# hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=512,
),
'eva02_large_patch14_clip_224.merged2b': _cfg(
# hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
hf_hub_id='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k', # float16 weights
hf_hub_filename='open_clip_pytorch_model.bin',
# hf_hub_id='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k', # float16 weights
# hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=768,
),
'eva02_large_patch14_clip_336.merged2b': _cfg(
# hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
hf_hub_id='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k', # float16 weights
hf_hub_filename='open_clip_pytorch_model.bin',
# hf_hub_id='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k', # float16 weights
# hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
input_size=(3, 336, 336), crop_pct=1.0,
num_classes=768,
),
'eva02_enormous_patch14_clip_224.laion2b': _cfg(
# hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_E_psz14_plus_s9B.pt',
hf_hub_id='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k', # float16 weights
hf_hub_filename='open_clip_pytorch_model.bin',
# hf_hub_id='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k', # float16 weights
# hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=1024,
),
'eva02_enormous_patch14_clip_224.laion2b_plus': _cfg(
# hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_E_psz14_plus_s9B.pt',
hf_hub_id='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k', # bfloat16 weights
hf_hub_filename='open_clip_pytorch_model.bin',
# hf_hub_id='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k', # bfloat16 weights
# hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=1024,
),
'eva02_enormous_patch14_clip_224.pretrain': _cfg(

View File

@@ -530,26 +530,47 @@ def _cfg(url='', **kwargs):
default_cfgs = generate_default_cfgs({
"sam2_hiera_tiny.r224": _cfg(
hf_hub_id='facebook/sam2-hiera-tiny',
hf_hub_filename='sam2_hiera_tiny.pt',
input_size=(3, 224, 224), pool_size=(7, 7),
), # FIXME reduced res for testing
"sam2_hiera_tiny.r896": _cfg(
hf_hub_id='facebook/sam2-hiera-tiny',
hf_hub_filename='sam2_hiera_tiny.pt',
"sam2_hiera_tiny.fb_r896": _cfg(
# hf_hub_id='facebook/sam2-hiera-tiny',
# hf_hub_filename='sam2_hiera_tiny.pt',
hf_hub_id='timm/',
),
"sam2_hiera_small": _cfg(
hf_hub_id='facebook/sam2-hiera-small',
hf_hub_filename='sam2_hiera_small.pt',
"sam2_hiera_tiny.fb_r896_2pt1": _cfg(
# hf_hub_id='facebook/sam2.1-hiera-tiny',
# hf_hub_filename='sam2.1_hiera_tiny.pt',
hf_hub_id='timm/',
),
"sam2_hiera_base_plus": _cfg(
hf_hub_id='facebook/sam2-hiera-base-plus',
hf_hub_filename='sam2_hiera_base_plus.pt',
"sam2_hiera_small.fb_r896": _cfg(
# hf_hub_id='facebook/sam2-hiera-small',
# hf_hub_filename='sam2_hiera_small.pt',
hf_hub_id='timm/',
),
"sam2_hiera_large": _cfg(
hf_hub_id='facebook/sam2-hiera-large',
hf_hub_filename='sam2_hiera_large.pt',
"sam2_hiera_small.fb_r896_2pt1": _cfg(
# hf_hub_id='facebook/sam2.1-hiera-small',
# hf_hub_filename='sam2.1_hiera_small.pt',
hf_hub_id='timm/',
),
"sam2_hiera_base_plus.fb_r896": _cfg(
# hf_hub_id='facebook/sam2-hiera-base-plus',
# hf_hub_filename='sam2_hiera_base_plus.pt',
hf_hub_id='timm/',
),
"sam2_hiera_base_plus.fb_r896_2pt1": _cfg(
# hf_hub_id='facebook/sam2.1-hiera-base-plus',
# hf_hub_filename='sam2.1_hiera_base_plus.pt',
hf_hub_id='timm/',
),
"sam2_hiera_large.fb_r1024": _cfg(
# hf_hub_id='facebook/sam2-hiera-large',
# hf_hub_filename='sam2_hiera_large.pt',
hf_hub_id='timm/',
min_input_size=(3, 256, 256),
input_size=(3, 1024, 1024), pool_size=(32, 32),
),
"sam2_hiera_large.fb_r1024_2pt1": _cfg(
# hf_hub_id='facebook/sam2.1-hiera-large',
# hf_hub_filename='sam2.1_hiera_large.pt',
hf_hub_id='timm/',
min_input_size=(3, 256, 256),
input_size=(3, 1024, 1024), pool_size=(32, 32),
),
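With the sam2 variants renamed to explicit .fb_r896 / .fb_r1024 (and _2pt1) tags and pointed at 'timm/' weight instances, the image_encoder.trunk. prefix remap and pretrained_strict=False workaround can be dropped, as the hunk below shows. A minimal loading sketch, assuming the re-hosted trunk-only weights are available:

import timm

# Hiera trunk as a feature extractor; num_classes=0 requests pooled features
# without a classifier head.
backbone = timm.create_model(
    'sam2_hiera_base_plus.fb_r896',
    pretrained=True,
    num_classes=0,
)
backbone.eval()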
@@ -578,11 +599,11 @@ def checkpoint_filter_fn(state_dict, model=None, prefix=''):
def _create_hiera_det(variant: str, pretrained: bool = False, **kwargs) -> HieraDet:
out_indices = kwargs.pop('out_indices', 4)
checkpoint_prefix = ''
if 'sam2' in variant:
# SAM2 pretrained weights have no classifier or final norm-layer (`head.norm`)
# This is workaround loading with num_classes=0 w/o removing norm-layer.
kwargs.setdefault('pretrained_strict', False)
checkpoint_prefix = 'image_encoder.trunk.'
# if 'sam2' in variant:
# # SAM2 pretrained weights have no classifier or final norm-layer (`head.norm`)
# # This is workaround loading with num_classes=0 w/o removing norm-layer.
# kwargs.setdefault('pretrained_strict', False)
# checkpoint_prefix = 'image_encoder.trunk.'
return build_model_with_cfg(
HieraDet,
variant,

View File

@@ -912,26 +912,40 @@ def resize_pos_embed(
@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = '') -> None:
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = '', load_bfloat16: bool = False) -> None:
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
"""
import numpy as np
if load_bfloat16:
import jax.numpy as jnp
import ml_dtypes
def _n2p(w, t=True, idx=None):
def _n2p(_w, t=True, idx=None):
if idx is not None:
w = w[idx]
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
w = w.flatten()
if t:
if w.ndim == 4:
w = w.transpose([3, 2, 0, 1])
elif w.ndim == 3:
w = w.transpose([2, 0, 1])
elif w.ndim == 2:
w = w.transpose([1, 0])
return torch.from_numpy(w)
_w = _w[idx]
if load_bfloat16:
_w = _w.view(ml_dtypes.bfloat16).astype(jnp.float32)
_w = np.array(_w)
if _w.ndim == 4 and _w.shape[0] == _w.shape[1] == _w.shape[2] == 1:
_w = _w.flatten()
if t:
if _w.ndim == 4:
_w = _w.transpose([3, 2, 0, 1])
elif _w.ndim == 3:
_w = _w.transpose([2, 0, 1])
elif _w.ndim == 2:
_w = _w.transpose([1, 0])
_w = torch.from_numpy(_w)
return _w
if load_bfloat16:
w = jnp.load(checkpoint_path)
else:
w = np.load(checkpoint_path)
w = np.load(checkpoint_path)
interpolation = 'bilinear'
antialias = False
big_vision = False
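A minimal numpy-side sketch of the load_bfloat16 path above (the diff itself goes through jnp.load and jnp.float32; this analog assumes the checkpoint arrays arrive as raw 2-byte buffers):

import numpy as np
import ml_dtypes  # provides a numpy-compatible bfloat16 dtype
import torch

raw = np.arange(6, dtype=np.uint16).reshape(2, 3)     # stand-in for an .npz array
w = raw.view(ml_dtypes.bfloat16).astype(np.float32)   # reinterpret bytes, widen to float32
t = torch.from_numpy(w)                               # from_numpy can consume float32 directly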
@@ -1593,18 +1607,18 @@ default_cfgs = {
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),
'vit_base_patch32_clip_224.laion400m_e32': _cfg(
hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
notes=('natively QuickGELU, use quickgelu model variant for original results',),
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_base_patch16_clip_224.laion400m_e32': _cfg(
hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_base_patch16_plus_clip_240.laion400m_e32': _cfg(
hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 240, 240), crop_pct=1.0, num_classes=512),
input_size=(3, 240, 240), crop_pct=1.0, num_classes=640),
'vit_large_patch14_clip_224.laion400m_e32': _cfg(
hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
'vit_base_patch32_clip_224.datacompxl': _cfg(
@@ -1622,22 +1636,18 @@ default_cfgs = {
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
'vit_base_patch16_clip_224.dfn2b': _cfg(
hf_hub_id='apple/DFN2B-CLIP-ViT-B-16',
hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_large_patch14_clip_224.dfn2b': _cfg(
hf_hub_id='apple/DFN2B-CLIP-ViT-L-14',
hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
notes=('natively QuickGELU, use quickgelu model variant for original results',),
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
'vit_huge_patch14_clip_224.dfn5b': _cfg(
hf_hub_id='apple/DFN5B-CLIP-ViT-H-14',
hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
notes=('natively QuickGELU, use quickgelu model variant for original results',),
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_huge_patch14_clip_378.dfn5b': _cfg(
hf_hub_id='apple/DFN5B-CLIP-ViT-H-14-378',
hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
notes=('natively QuickGELU, use quickgelu model variant for original results',),
crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024),
@@ -1700,7 +1710,7 @@ default_cfgs = {
notes=('natively QuickGELU, use quickgelu model variant for original results',),
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
'vit_large_patch14_clip_336.openai': _cfg(
hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
notes=('natively QuickGELU, use quickgelu model variant for original results',),
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 336, 336), num_classes=768),
@@ -1907,15 +1917,22 @@ default_cfgs = {
hf_hub_id='timm/',
num_classes=0),
'vit_so400m_patch14_siglip_gap_224.pali_mix': _cfg(
hf_hub_id='google/paligemma-3b-mix-224-jax',
hf_hub_filename='paligemma-3b-mix-224.npz',
custom_load='hf',
hf_hub_id='timm/',
num_classes=0),
'vit_so400m_patch14_siglip_gap_224.pali_pt': _cfg(
hf_hub_id='google/paligemma-3b-pt-224-jax',
hf_hub_filename='paligemma-3b-pt-224.npz',
custom_load='hf',
hf_hub_id='timm/',
num_classes=0),
'vit_so400m_patch14_siglip_gap_224.pali2_3b_pt': _cfg(
hf_hub_id='timm/',
num_classes=0),
'vit_so400m_patch14_siglip_gap_224.pali2_10b_pt': _cfg(
hf_hub_id='timm/',
num_classes=0),
# 'vit_so400m_patch14_siglip_gap_224.pali2_28b_pt': _cfg(
# hf_hub_id='google/paligemma2-28b-pt-224-jax',
# hf_hub_filename='pt_27b_224.npz',
# custom_load='hf',
# num_classes=0),
'vit_so400m_patch16_siglip_gap_256.webli_i18n': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256),
@@ -1929,23 +1946,69 @@ default_cfgs = {
input_size=(3, 384, 384), crop_pct=1.0,
num_classes=0),
'vit_so400m_patch14_siglip_gap_448.pali_mix': _cfg(
hf_hub_id='google/paligemma-3b-mix-448-jax',
hf_hub_filename='paligemma-3b-mix-448.npz',
custom_load='hf',
hf_hub_id='timm/',
input_size=(3, 448, 448), crop_pct=1.0,
num_classes=0),
'vit_so400m_patch14_siglip_gap_448.pali_pt': _cfg(
hf_hub_id='google/paligemma-3b-pt-448-jax',
hf_hub_filename='paligemma-3b-pt-448.npz',
custom_load='hf',
hf_hub_id='timm/',
input_size=(3, 448, 448), crop_pct=1.0,
num_classes=0),
'vit_so400m_patch14_siglip_gap_448.pali_refcoco_seg': _cfg(
hf_hub_id='timm/',
input_size=(3, 448, 448), crop_pct=1.0,
num_classes=0),
'vit_so400m_patch14_siglip_gap_448.pali_ocrvqa': _cfg(
hf_hub_id='timm/',
input_size=(3, 448, 448), crop_pct=1.0,
num_classes=0),
'vit_so400m_patch14_siglip_gap_448.pali2_3b_pt': _cfg(
hf_hub_id='timm/',
input_size=(3, 448, 448), crop_pct=1.0,
num_classes=0),
'vit_so400m_patch14_siglip_gap_448.pali2_10b_pt': _cfg(
hf_hub_id='timm/',
input_size=(3, 448, 448), crop_pct=1.0,
num_classes=0),
# 'vit_so400m_patch14_siglip_gap_448.pali2_28b_pt': _cfg(
# hf_hub_id='google/paligemma2-28b-pt-448-jax',
# hf_hub_filename='pt_27b_448.npz',
# custom_load='hf',
# input_size=(3, 448, 448), crop_pct=1.0,
# num_classes=0),
'vit_so400m_patch14_siglip_gap_448.pali2_3b_docci': _cfg(
hf_hub_id='timm/',
input_size=(3, 448, 448), crop_pct=1.0,
num_classes=0),
'vit_so400m_patch14_siglip_gap_448.pali2_10b_docci': _cfg(
hf_hub_id='timm/',
input_size=(3, 448, 448), crop_pct=1.0,
num_classes=0),
'vit_so400m_patch14_siglip_gap_896.pali_pt': _cfg(
hf_hub_id='google/paligemma-3b-pt-896-jax',
hf_hub_filename='paligemma-3b-pt-896.npz',
custom_load='hf',
hf_hub_id='timm/',
input_size=(3, 896, 896), crop_pct=1.0,
num_classes=0),
'vit_so400m_patch14_siglip_gap_896.pali_refcoco_seg': _cfg(
hf_hub_id='timm/',
input_size=(3, 896, 896), crop_pct=1.0,
num_classes=0),
'vit_so400m_patch14_siglip_gap_896.pali_ocrvqa': _cfg(
hf_hub_id='timm/',
input_size=(3, 896, 896), crop_pct=1.0,
num_classes=0),
'vit_so400m_patch14_siglip_gap_896.pali2_3b_pt': _cfg(
hf_hub_id='timm/',
input_size=(3, 896, 896), crop_pct=1.0,
num_classes=0),
'vit_so400m_patch14_siglip_gap_896.pali2_10b_pt': _cfg(
hf_hub_id='timm/',
input_size=(3, 896, 896), crop_pct=1.0,
num_classes=0),
# 'vit_so400m_patch14_siglip_gap_896.pali2_28b_pt': _cfg(
# hf_hub_id='google/paligemma2-28b-pt-896-jax',
# hf_hub_filename='pt_27b_896.npz',
# custom_load='hf',
# input_size=(3, 896, 896), crop_pct=1.0,
# num_classes=0),
'vit_so400m_patch14_siglip_378.webli_ft_in1k': _cfg(
hf_hub_id='timm/',
@@ -1958,22 +2021,18 @@ default_cfgs = {
'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
license='mit',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_medium_patch32_clip_224.tinyclip_laion400m': _cfg(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
license='mit',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_medium_patch16_clip_224.tinyclip_yfcc15m': _cfg(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
license='mit',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_betwixt_patch32_clip_224.tinyclip_laion400m': _cfg(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
license='mit',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),