siglip2 weights on hub, fix forward_intermediates when no prefix tokens (& return prefix selected)

pull/2444/head
Ross Wightman 2025-02-21 12:46:14 -08:00 committed by Ross Wightman
parent f63a11cf81
commit a667d3d8f0
1 changed files with 34 additions and 31 deletions

View File

@ -769,11 +769,14 @@ class VisionTransformer(nn.Module):
# split prefix (e.g. class, distill) and spatial feature tokens
prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
else:
prefix_tokens = None
if reshape:
# reshape to BCHW output format
H, W = self.patch_embed.dynamic_feat_size((height, width))
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
if not torch.jit.is_scripting() and return_prefix_tokens:
if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None:
# return_prefix not support in torchscript due to poor type handling
intermediates = list(zip(intermediates, prefix_tokens))
@ -1889,17 +1892,17 @@ default_cfgs = {
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
'vit_base_patch32_siglip_256.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_base_patch16_siglip_224.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
num_classes=0),
'vit_base_patch16_siglip_224.webli': _cfg(
hf_hub_id='timm/',
num_classes=0),
'vit_base_patch16_siglip_256.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_base_patch16_siglip_256.webli': _cfg(
@ -1911,7 +1914,7 @@ default_cfgs = {
input_size=(3, 256, 256),
num_classes=0),
'vit_base_patch16_siglip_384.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_base_patch16_siglip_384.webli': _cfg(
@ -1919,7 +1922,7 @@ default_cfgs = {
input_size=(3, 384, 384),
num_classes=0),
'vit_base_patch16_siglip_512.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_base_patch16_siglip_512.webli': _cfg(
@ -1927,7 +1930,7 @@ default_cfgs = {
input_size=(3, 512, 512),
num_classes=0),
'vit_large_patch16_siglip_256.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_large_patch16_siglip_256.webli': _cfg(
@ -1935,7 +1938,7 @@ default_cfgs = {
input_size=(3, 256, 256),
num_classes=0),
'vit_large_patch16_siglip_384.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_large_patch16_siglip_384.webli': _cfg(
@ -1943,17 +1946,17 @@ default_cfgs = {
input_size=(3, 384, 384),
num_classes=0),
'vit_large_patch16_siglip_512.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_so400m_patch14_siglip_224.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
num_classes=0),
'vit_so400m_patch14_siglip_224.webli': _cfg(
hf_hub_id='timm/',
num_classes=0),
'vit_so400m_patch14_siglip_378.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 378, 378),
num_classes=0),
'vit_so400m_patch14_siglip_378.webli': _cfg(
@ -1965,7 +1968,7 @@ default_cfgs = {
input_size=(3, 384, 384),
num_classes=0),
'vit_so400m_patch16_siglip_256.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_so400m_patch16_siglip_256.webli_i18n': _cfg(
@ -1973,34 +1976,34 @@ default_cfgs = {
input_size=(3, 256, 256),
num_classes=0),
'vit_so400m_patch16_siglip_384.v2_webli': _cfg(
#hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_so400m_patch16_siglip_512.v2_webli': _cfg(
#hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_giantopt_patch16_siglip_256.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_giantopt_patch16_siglip_384.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_base_patch32_siglip_gap_256.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_base_patch16_siglip_gap_224.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
num_classes=0),
'vit_base_patch16_siglip_gap_224.webli': _cfg(
hf_hub_id='timm/',
num_classes=0),
'vit_base_patch16_siglip_gap_256.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_base_patch16_siglip_gap_256.webli': _cfg(
@ -2012,7 +2015,7 @@ default_cfgs = {
input_size=(3, 256, 256),
num_classes=0),
'vit_base_patch16_siglip_gap_384.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_base_patch16_siglip_gap_384.webli': _cfg(
@ -2020,7 +2023,7 @@ default_cfgs = {
input_size=(3, 384, 384),
num_classes=0),
'vit_base_patch16_siglip_gap_512.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_base_patch16_siglip_gap_512.webli': _cfg(
@ -2028,7 +2031,7 @@ default_cfgs = {
input_size=(3, 512, 512),
num_classes=0),
'vit_large_patch16_siglip_gap_256.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_large_patch16_siglip_gap_256.webli': _cfg(
@ -2036,7 +2039,7 @@ default_cfgs = {
input_size=(3, 256, 256),
num_classes=0),
'vit_large_patch16_siglip_gap_384.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_large_patch16_siglip_gap_384.webli': _cfg(
@ -2044,11 +2047,11 @@ default_cfgs = {
input_size=(3, 384, 384),
num_classes=0),
'vit_large_patch16_siglip_gap_512.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_so400m_patch14_siglip_gap_224.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
num_classes=0),
'vit_so400m_patch14_siglip_gap_224.webli': _cfg(
hf_hub_id='timm/',
@ -2071,7 +2074,7 @@ default_cfgs = {
# custom_load='hf',
# num_classes=0),
'vit_so400m_patch14_siglip_gap_378.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 378, 378),
num_classes=0),
'vit_so400m_patch14_siglip_gap_378.webli': _cfg(
@ -2147,7 +2150,7 @@ default_cfgs = {
# input_size=(3, 896, 896), crop_pct=1.0,
# num_classes=0),
'vit_so400m_patch16_siglip_gap_256.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_so400m_patch16_siglip_gap_256.webli_i18n': _cfg(
@ -2155,19 +2158,19 @@ default_cfgs = {
input_size=(3, 256, 256),
num_classes=0),
'vit_so400m_patch16_siglip_gap_384.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_so400m_patch16_siglip_gap_512.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_giantopt_patch16_siglip_gap_256.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_giantopt_patch16_siglip_gap_384.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),