From 3582ca499ed5e878b2deff1eb505b12666e2e36c Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 10 May 2024 14:14:06 -0700
Subject: [PATCH] Prepping weight push, benchmarking.

---
 timm/models/eva.py                | 24 ++++-----
 timm/models/vision_transformer.py | 83 ++++++++++++++++++++++++++-----
 timm/version.py                   |  2 +-
 3 files changed, 83 insertions(+), 26 deletions(-)

diff --git a/timm/models/eva.py b/timm/models/eva.py
index e2eeed60..406d6b9f 100644
--- a/timm/models/eva.py
+++ b/timm/models/eva.py
@@ -718,10 +718,10 @@ def checkpoint_filter_fn(
             continue
 
         # FIXME here while import new weights, to remove
-        # if k == 'cls_token':
-        #     print('DEBUG: cls token -> reg')
-        #     k = 'reg_token'
-        #     #v = v + state_dict['pos_embed'][0, :]
+        if k == 'cls_token':
+            print('DEBUG: cls token -> reg')
+            k = 'reg_token'
+            #v = v + state_dict['pos_embed'][0, :]
 
         if 'patch_embed.proj.weight' in k:
             _, _, H, W = model.patch_embed.proj.weight.shape
@@ -951,26 +951,26 @@ default_cfgs = generate_default_cfgs({
         num_classes=0,
     ),
 
-    'vit_medium_patch16_rope_reg1_gap_256.in1k': _cfg(
+    'vit_medium_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
         #hf_hub_id='timm/',
-        #file='vit_medium_gap1_rope-in1k-20230920-5.pth',
+        file='vit_medium_gap1_rope-in1k-20230920-5.pth',
         input_size=(3, 256, 256), crop_pct=0.95,
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
     ),
-    'vit_mediumd_patch16_rope_reg1_gap_256.in1k': _cfg(
+    'vit_mediumd_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
         #hf_hub_id='timm/',
-        #file='vit_mediumd_gap1_rope-in1k-20230926-5.pth',
+        file='vit_mediumd_gap1_rope-in1k-20230926-5.pth',
         input_size=(3, 256, 256), crop_pct=0.95,
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
     ),
-    'vit_betwixt_patch16_rope_reg4_gap_256.in1k': _cfg(
+    'vit_betwixt_patch16_rope_reg4_gap_256.sbb_in1k': _cfg(
         #hf_hub_id='timm/',
-        #file='vit_betwixt_gap4_rope-in1k-20231005-5.pth',
+        file='vit_betwixt_gap4_rope-in1k-20231005-5.pth',
         input_size=(3, 256, 256), crop_pct=0.95,
     ),
-    'vit_base_patch16_rope_reg1_gap_256.in1k': _cfg(
+    'vit_base_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
         #hf_hub_id='timm/',
-        #file='vit_base_gap1_rope-in1k-20230930-5.pth',
+        file='vit_base_gap1_rope-in1k-20230930-5.pth',
         input_size=(3, 256, 256), crop_pct=0.95,
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
     ),
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 61f3e6eb..2dc6ee0c 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -428,6 +428,7 @@ class VisionTransformer(nn.Module):
             act_layer: Optional[LayerType] = None,
             block_fn: Type[nn.Module] = Block,
             mlp_layer: Type[nn.Module] = Mlp,
+            repr_size = False,
     ) -> None:
         """
         Args:
@@ -536,6 +537,14 @@ class VisionTransformer(nn.Module):
             )
         else:
             self.attn_pool = None
+        if repr_size:
+            repr_size = self.embed_dim if isinstance(repr_size, bool) else repr_size
+            self.repr = nn.Sequential(nn.Linear(self.embed_dim, repr_size), nn.Tanh())
+            embed_dim = repr_size
+            print(self.repr)
+        else:
+            self.repr = nn.Identity()
+
         self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -752,6 +761,7 @@ class VisionTransformer(nn.Module):
             x = x[:, self.num_prefix_tokens:].mean(dim=1)
         elif self.global_pool:
             x = x[:, 0]  # class token
+        x = self.repr(x)
         x = self.fc_norm(x)
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
@@ -1790,23 +1800,40 @@ default_cfgs = {
         license='mit',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
 
-    'vit_wee_patch16_reg1_gap_256': _cfg(
+    'vit_wee_patch16_reg1_gap_256.sbb_in1k': _cfg(
        #file='',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_little_patch16_reg4_gap_256': _cfg(
-        #file='',
+    'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        file='./vit_pwee-in1k-8.pth',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_medium_patch16_reg1_gap_256': _cfg(
-        #file='vit_medium_gap1-in1k-20231118-8.pth',
+    'vit_little_patch16_reg4_gap_256.sbb_in1k': _cfg(
+        file='vit_little_patch16-in1k-8a.pth',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_medium_patch16_reg4_gap_256': _cfg(
-        #file='vit_medium_gap4-in1k-20231115-8.pth',
+    'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        file='vit_medium_gap1-in1k-20231118-8.pth',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_betwixt_patch16_reg1_gap_256': _cfg(
-        #file='vit_betwixt_gap1-in1k-20231121-8.pth',
+    'vit_medium_patch16_reg4_gap_256.sbb_in1k': _cfg(
+        file='vit_medium_gap4-in1k-20231115-8.pth',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_betwixt_patch16_reg4_gap_256': _cfg(
-        #file='vit_betwixt_gap4-in1k-20231106-8.pth',
+    'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
+        file='vit_mp_patch16_reg4-in1k-5a.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg(
+        file='vit_mp_patch16_reg4-in12k-8.pth',
+        num_classes=11821,
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_betwixt_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        file='vit_betwixt_gap1-in1k-20231121-8.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
+        file='vit_betwixt_patch16_reg4-ft-in1k-8b.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
+        file='vit_betwixt_gap4-in1k-20231106-8.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg(
+        file='vit_betwixt_gap4-in12k-8.pth',
+        num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_base_patch16_reg4_gap_256': _cfg(
         input_size=(3, 256, 256)),
@@ -1906,6 +1933,14 @@ def vit_small_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_small_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT-Small (ViT-S/16)
+    """
+    model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, global_pool='avg', class_token=False, repr_size=True)
+    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
 @register_model
 def vit_small_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-Small (ViT-S/16)
@@ -2755,10 +2790,21 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
 def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
         patch_size=16, embed_dim=256, depth=14, num_heads=4, init_values=1e-5, mlp_ratio=5,
+        class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg',
+    )
+    model = _create_vision_transformer(
+        'vit_wee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_pwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=256, depth=16, num_heads=4, init_values=1e-5, mlp_ratio=5,
         class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', block_fn=ParallelScalingBlock,
     )
     model = _create_vision_transformer(
-        'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_pwee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
@@ -2769,7 +2815,7 @@ def vit_little_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
         class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
     )
     model = _create_vision_transformer(
-        'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_little_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
@@ -2795,6 +2841,17 @@ def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_mediumd_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
+        class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
+    )
+    model = _create_vision_transformer(
+        'vit_mediumd_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_betwixt_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
diff --git a/timm/version.py b/timm/version.py
index 899e700f..c6092d3e 100644
--- a/timm/version.py
+++ b/timm/version.py
@@ -1 +1 @@
-__version__ = '1.0.0.dev0'
+__version__ = '1.0.1.dev0'
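
Note (reviewer sketch, not part of the commit): the new repr_size argument in vision_transformer.py adds a pre-logits "representation" layer, a Linear + Tanh projection applied to the pooled features before fc_norm and the classifier head (see the forward_head hunk above). A minimal standalone sketch of that head path with illustrative dimensions; repr_size=True resolves to embed_dim, as used by the new vit_small_patch16_gap_224:

    import torch
    import torch.nn as nn

    # Illustrative wiring of the patched head path; mirrors the diff above
    # but is not the timm implementation itself.
    embed_dim, num_classes = 384, 1000
    repr_layer = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.Tanh())  # self.repr
    fc_norm = nn.LayerNorm(embed_dim)                                       # self.fc_norm
    head = nn.Linear(embed_dim, num_classes)                                # self.head

    pooled = torch.randn(8, embed_dim)  # (B, embed_dim) after global average pooling
    logits = head(fc_norm(repr_layer(pooled)))
    assert logits.shape == (8, num_classes)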
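The renamed .sbb_in1k / .sbb_in12k configs above still point at local file= checkpoints rather than hf_hub_id entries, so at this commit the weights are staged but not yet published. The newly registered architectures can still be instantiated with random init for benchmarking; a usage sketch assuming only this patch is applied:

    import timm
    import torch

    # vit_mediumd_patch16_reg4_gap_256 is registered by this patch; its
    # pretrained weights are not on the HF hub yet, so keep pretrained=False.
    model = timm.create_model('vit_mediumd_patch16_reg4_gap_256', pretrained=False)
    out = model(torch.randn(1, 3, 256, 256))
    print(out.shape)  # torch.Size([1, 1000])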