sbb vit weights on hub, testing

This commit is contained in:
Ross Wightman 2024-05-10 17:15:01 -07:00
parent 3582ca499e
commit aa4d06a11c
2 changed files with 32 additions and 41 deletions
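Context for the change: the sbb ViT pretrained configs below switch from local file= checkpoint paths to hf_hub_id='timm/', so pretrained weights resolve from the Hugging Face Hub instead of local files. A minimal usage sketch, assuming the corresponding weight uploads under the timm Hub org are complete:

import timm

# Hedged sketch: once the sbb checkpoints are published on the Hub, creating a
# model with pretrained=True resolves the weights via hf_hub_id + model name.
model = timm.create_model('vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k', pretrained=True)
model = model.eval()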

View File

@@ -717,11 +717,10 @@ def checkpoint_filter_fn(
# fixed embedding no need to load buffer from checkpoint
continue
- # FIXME here while import new weights, to remove
- if k == 'cls_token':
- print('DEBUG: cls token -> reg')
- k = 'reg_token'
- #v = v + state_dict['pos_embed'][0, :]
+ # FIXME here while importing new weights, to remove
+ # if k == 'cls_token':
+ # print('DEBUG: cls token -> reg')
+ # k = 'reg_token'
if 'patch_embed.proj.weight' in k:
_, _, H, W = model.patch_embed.proj.weight.shape
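The hunk above comments out a temporary remap in checkpoint_filter_fn that renamed cls_token to reg_token while the new weights were being imported. As an illustrative sketch only (the function name below is hypothetical, not the actual timm code), a filter of this kind just rewrites state_dict keys before loading:

def remap_cls_to_reg(state_dict):
    # Illustrative only: rename the legacy class token to a register token.
    out = {}
    for k, v in state_dict.items():
        if k == 'cls_token':
            k = 'reg_token'
        out[k] = v
    return out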
@@ -952,25 +951,25 @@ default_cfgs = generate_default_cfgs({
),
'vit_medium_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
- #hf_hub_id='timm/',
- file='vit_medium_gap1_rope-in1k-20230920-5.pth',
+ hf_hub_id='timm/',
+ #file='vit_medium_gap1_rope-in1k-20230920-5.pth',
input_size=(3, 256, 256), crop_pct=0.95,
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
),
'vit_mediumd_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
- #hf_hub_id='timm/',
- file='vit_mediumd_gap1_rope-in1k-20230926-5.pth',
+ hf_hub_id='timm/',
+ #file='vit_mediumd_gap1_rope-in1k-20230926-5.pth',
input_size=(3, 256, 256), crop_pct=0.95,
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
),
'vit_betwixt_patch16_rope_reg4_gap_256.sbb_in1k': _cfg(
- #hf_hub_id='timm/',
- file='vit_betwixt_gap4_rope-in1k-20231005-5.pth',
+ hf_hub_id='timm/',
+ #file='vit_betwixt_gap4_rope-in1k-20231005-5.pth',
input_size=(3, 256, 256), crop_pct=0.95,
),
'vit_base_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
- #hf_hub_id='timm/',
- file='vit_base_gap1_rope-in1k-20230930-5.pth',
+ hf_hub_id='timm/',
+ #file='vit_base_gap1_rope-in1k-20230930-5.pth',
input_size=(3, 256, 256), crop_pct=0.95,
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
),
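Each _cfg entry above describes where a pretrained checkpoint lives and how inputs are preprocessed (input_size, crop_pct, mean/std). With hf_hub_id='timm/' active and file= commented out, timm resolves the checkpoint on the Hub from the model name and tag. A small sketch for inspecting the resolved config, assuming a timm version that exposes pretrained_cfg as a dict on the model:

import timm

# Hedged sketch: check which source the pretrained config points at.
m = timm.create_model('vit_medium_patch16_rope_reg1_gap_256.sbb_in1k', pretrained=False)
print(m.pretrained_cfg.get('hf_hub_id'))   # expected 'timm/' after this change
print(m.pretrained_cfg.get('input_size'))  # (3, 256, 256)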

View File

@@ -428,7 +428,6 @@ class VisionTransformer(nn.Module):
act_layer: Optional[LayerType] = None,
block_fn: Type[nn.Module] = Block,
mlp_layer: Type[nn.Module] = Mlp,
- repr_size = False,
) -> None:
"""
Args:
@@ -537,14 +536,6 @@
)
else:
self.attn_pool = None
- if repr_size:
- repr_size = self.embed_dim if isinstance(repr_size, bool) else repr_size
- self.repr = nn.Sequential(nn.Linear(self.embed_dim, repr_size), nn.Tanh())
- embed_dim = repr_size
- print(self.repr)
- else:
- self.repr = nn.Identity()
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
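The deleted block above built an experimental 'repr' stage: when repr_size was truthy, a Linear + Tanh projection was inserted ahead of the classifier, with repr_size=True meaning a width of embed_dim and an int meaning an explicit width. For anyone who relied on it, a hedged standalone equivalent (not part of timm after this commit):

import torch.nn as nn

def build_pre_logits(embed_dim: int, repr_size=False) -> nn.Module:
    # Mirrors the removed logic: False -> no-op, True -> embed_dim width, int -> explicit width.
    if not repr_size:
        return nn.Identity()
    width = embed_dim if isinstance(repr_size, bool) else repr_size
    return nn.Sequential(nn.Linear(embed_dim, width), nn.Tanh())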
@@ -761,7 +752,6 @@
x = x[:, self.num_prefix_tokens:].mean(dim=1)
elif self.global_pool:
x = x[:, 0] # class token
- x = self.repr(x)
x = self.fc_norm(x)
x = self.head_drop(x)
return x if pre_logits else self.head(x)
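With x = self.repr(x) dropped, the head path is simply pool -> fc_norm -> dropout -> classifier. A simplified sketch of the resulting flow (paraphrased from forward_head, not a verbatim copy of the file):

def forward_head(self, x, pre_logits: bool = False):
    if self.attn_pool is not None:
        x = self.attn_pool(x)
    elif self.global_pool == 'avg':
        x = x[:, self.num_prefix_tokens:].mean(dim=1)
    elif self.global_pool:
        x = x[:, 0]  # class token
    x = self.fc_norm(x)
    x = self.head_drop(x)
    return x if pre_logits else self.head(x)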
@@ -1804,35 +1794,45 @@ default_cfgs = {
#file='',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
- file='./vit_pwee-in1k-8.pth',
+ #file='./vit_pwee-in1k-8.pth',
+ hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_little_patch16_reg4_gap_256.sbb_in1k': _cfg(
- file='vit_little_patch16-in1k-8a.pth',
+ #file='vit_little_patch16-in1k-8a.pth',
+ hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg(
- file='vit_medium_gap1-in1k-20231118-8.pth',
+ #file='vit_medium_gap1-in1k-20231118-8.pth',
+ hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_medium_patch16_reg4_gap_256.sbb_in1k': _cfg(
- file='vit_medium_gap4-in1k-20231115-8.pth',
+ #file='vit_medium_gap4-in1k-20231115-8.pth',
+ hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
- file='vit_mp_patch16_reg4-in1k-5a.pth',
+ #file='vit_mp_patch16_reg4-in1k-5a.pth',
+ hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg(
- file='vit_mp_patch16_reg4-in12k-8.pth',
+ #file='vit_mp_patch16_reg4-in12k-8.pth',
+ hf_hub_id='timm/',
num_classes=11821,
input_size=(3, 256, 256), crop_pct=0.95),
'vit_betwixt_patch16_reg1_gap_256.sbb_in1k': _cfg(
- file='vit_betwixt_gap1-in1k-20231121-8.pth',
+ #file='vit_betwixt_gap1-in1k-20231121-8.pth',
+ hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
- file='vit_betwixt_patch16_reg4-ft-in1k-8b.pth',
+ #file='vit_betwixt_patch16_reg4-ft-in1k-8b.pth',
+ hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
- file='vit_betwixt_gap4-in1k-20231106-8.pth',
+ #file='vit_betwixt_gap4-in1k-20231106-8.pth',
+ hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg(
- file='vit_betwixt_gap4-in12k-8.pth',
+ #file='vit_betwixt_gap4-in12k-8.pth',
+ hf_hub_id='timm/',
num_classes=11821,
input_size=(3, 256, 256), crop_pct=0.95),
'vit_base_patch16_reg4_gap_256': _cfg(
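With the whole sbb family above pointing at the Hub, the variants can be enumerated and pulled without local checkpoint files. A hedged sketch, assuming an installed timm release that includes these configs:

import timm

# List the sbb ViT pretrained tags known to the installed timm version.
for name in timm.list_pretrained('vit_*_gap_256.sbb_*'):
    print(name)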
@@ -1933,14 +1933,6 @@ def vit_small_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransform
return model
- @register_model
- def vit_small_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
- """ ViT-Small (ViT-S/16)
- """
- model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, global_pool='avg', class_token=False, repr_size=True)
- model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
- return model
@register_model
def vit_small_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT-Small (ViT-S/16)