diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 2dbe6ff7..d7dabad6 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -42,9 +42,9 @@ from torch.jit import Final
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
+from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \
     trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
-    SwiGLU, get_act_layer, get_norm_layer, LayerType
+    get_act_layer, get_norm_layer, LayerType
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
@@ -1159,13 +1159,16 @@ def _convert_aimv2(
         k = k.replace('trunk.', '')
         k = k.replace('post_trunk_norm.', 'norm.')
-        if 'mlp.fc1' in k:
-            if k in out_dict:
-                v = torch.cat([v, out_dict[k]], dim=0)
-        elif 'mlp.fc3' in k:
-            k = k.replace('mlp.fc3', 'mlp.fc1')
-            if k in out_dict:
-                v = torch.cat([out_dict[k], v], dim=0)
+        # packed ver, FIXME to delete
+        # if 'mlp.fc1' in k:
+        #     if k in out_dict:
+        #         v = torch.cat([v, out_dict[k]], dim=0)
+        # elif 'mlp.fc3' in k:
+        #     k = k.replace('mlp.fc3', 'mlp.fc1')
+        #     if k in out_dict:
+        #         v = torch.cat([out_dict[k], v], dim=0)
+        k = k.replace('mlp.fc1', 'mlp.fc1_g')
+        k = k.replace('mlp.fc3', 'mlp.fc1_x')
         out_dict[k] = v
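The remap above stops concatenating the two AIM-v2 MLP projections into a single packed `mlp.fc1` and instead renames `mlp.fc1`/`mlp.fc3` to the split-layout keys `mlp.fc1_g`/`mlp.fc1_x`. A minimal sketch of the split layout those keys map onto, assuming the gated branch carries the activation as in the AIM-v2 reference MLP; this is an illustrative class, not the `timm.layers.SwiGLU` implementation (which also has optional norm/dropout):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SwiGLUSketch(nn.Module):
        # Split SwiGLU layout: fc1_g is the gated (activated) branch, fc1_x
        # the linear branch; biases off to match the proj_bias=False configs.
        def __init__(self, dim: int, hidden: int):
            super().__init__()
            self.fc1_g = nn.Linear(dim, hidden, bias=False)  # <- 'mlp.fc1' keys
            self.fc1_x = nn.Linear(dim, hidden, bias=False)  # <- 'mlp.fc3' keys
            self.fc2 = nn.Linear(hidden, dim, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.fc2(F.silu(self.fc1_g(x)) * self.fc1_x(x))
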
@@ -1682,18 +1685,27 @@ default_cfgs = {
     'vit_base_patch16_clip_224.dfn2b': _cfg(
         hf_hub_id='timm/',
+        license='apple-ascl',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+    'vit_large_patch14_clip_224.dfn2b_s39b': _cfg(
+        #hf_hub_id='timm/',
+        hf_hub_id='apple/DFN2B-CLIP-ViT-L-14-39B', hf_hub_filename='open_clip_pytorch_model.bin',
+        license='apple-ascl',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
     'vit_large_patch14_clip_224.dfn2b': _cfg(
         hf_hub_id='timm/',
+        license='apple-ascl',
         notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
     'vit_huge_patch14_clip_224.dfn5b': _cfg(
         hf_hub_id='timm/',
+        license='apple-ascl',
         notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
     'vit_huge_patch14_clip_378.dfn5b': _cfg(
         hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        license='apple-ascl',
         notes=('natively QuickGELU, use quickgelu model variant for original results',),
         crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024),
@@ -2164,11 +2176,62 @@ default_cfgs = {
         input_size=(3, 448, 448), crop_pct=1.0,
         num_classes=0,
     ),
-    'vit_large_patch14_aimv2_224': _cfg(
+    'aimv2_large_patch14_224.apple_pt': _cfg(
         hf_hub_id='apple/aimv2-large-patch14-224',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-        input_size=(3, 224, 224), crop_pct=1.0,
-        num_classes=0),
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        crop_pct=1.0, num_classes=0),
+    'aimv2_large_patch14_224.apple_pt_dist': _cfg(
+        hf_hub_id='apple/aimv2-large-patch14-224-distilled',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        crop_pct=1.0, num_classes=0),
+    'aimv2_huge_patch14_224.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-huge-patch14-224',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        crop_pct=1.0, num_classes=0),
+    'aimv2_1b_patch14_224.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-1b-patch14-224',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        crop_pct=1.0, num_classes=0),
+    'aimv2_3b_patch14_224.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-3b-patch14-224',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        crop_pct=1.0, num_classes=0),
+    'aimv2_large_patch14_336.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-large-patch14-336',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
+    'aimv2_large_patch14_336.apple_pt_dist': _cfg(
+        hf_hub_id='apple/aimv2-large-patch14-336-distilled',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
+    'aimv2_huge_patch14_336.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-huge-patch14-336',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
+    'aimv2_1b_patch14_336.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-1b-patch14-336',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
+    'aimv2_3b_patch14_336.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-3b-patch14-336',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
+    'aimv2_large_patch14_448.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-large-patch14-448',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
+    'aimv2_huge_patch14_448.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-huge-patch14-448',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
+    'aimv2_1b_patch14_448.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-1b-patch14-448',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
+    'aimv2_3b_patch14_448.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-3b-patch14-448',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
     'test_vit.r160_in1k': _cfg(
         hf_hub_id='timm/',
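Each new entry pairs an `.apple_pt` (or `.apple_pt_dist`) pretrained tag with an `apple/...` hf_hub_id and `num_classes=0`, so the default output is the pooled embedding rather than classifier logits. A usage sketch, assuming the weights resolve from the hub ids configured above:

    import timm
    import torch

    # num_classes=0 in the cfg -> no classifier head; the model returns the
    # global_pool='avg' embedding (1024-dim for the Large variant).
    model = timm.create_model('aimv2_large_patch14_224.apple_pt', pretrained=True)
    model = model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    print(feats.shape)  # expected: torch.Size([1, 1024])
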
@@ -3442,17 +3505,171 @@ def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTran
 
 
 @register_model
-def vit_large_patch14_aimv2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
-    """ ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.
+def aimv2_large_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Large AIM-v2 model
     """
-    rms_norm = partial(RmsNorm, eps=1e-5)
     model_args = dict(
-        patch_size=14, embed_dim=1024, depth=24, num_heads=16, class_token=False, fc_norm=False,
-        mlp_ratio=5.5, global_pool='avg', norm_layer=rms_norm, embed_norm_layer=rms_norm, mlp_layer=SwiGLUPacked,
-        qkv_bias=False, proj_bias=False, act_layer='silu'
+        patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
     )
     model = _create_vision_transformer(
-        'vit_large_patch14_aimv2_224', pretrained=pretrained, **dict(model_args, **kwargs))
+        'aimv2_large_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def aimv2_huge_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Huge AIM-v2 model
+    """
+
+    model_args = dict(
+        patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False,
+        mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_huge_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def aimv2_1b_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT 1B AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_1b_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def aimv2_3b_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT 3B AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False,
+        mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_3b_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def aimv2_large_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Large AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
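Note the halved `mlp_ratio` relative to the replaced packed definition (5.5 → 2.75): with `SwiGLUPacked` the single `fc1` held both branches stacked together, so its ratio counted twice the per-branch width. Illustrative arithmetic for the Large and Huge variants, derived only from the args above:

    # Per-branch SwiGLU hidden width = embed_dim * mlp_ratio.
    large_hidden = int(1024 * 2.75)     # 2816; packed fc1 was 2 * 2816 = 1024 * 5.5
    huge_hidden = round(1536 * 2.6667)  # 4096, i.e. mlp_ratio ~= 8/3
    assert 2 * large_hidden == int(1024 * 5.5)
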
+
+
+@register_model
+def aimv2_huge_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Huge AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False,
+        mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_huge_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def aimv2_1b_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT 1B AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_1b_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def aimv2_3b_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT 3B AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False,
+        mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_3b_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def aimv2_large_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Large AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def aimv2_huge_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Huge AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False,
+        mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_huge_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def aimv2_1b_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT 1B AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_1b_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
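The registrations above differ only in `embed_dim`, `num_heads`, `mlp_ratio`, and variant name. A hedged refactor sketch of how the shared kwargs could be factored out; the `_aimv2_args` helper is hypothetical and not part of this diff:

    from functools import partial
    from timm.layers import RmsNorm, SwiGLU

    def _aimv2_args(embed_dim: int, num_heads: int, mlp_ratio: float) -> dict:
        # Shared AIM-v2 settings repeated in every register_model function above.
        return dict(
            patch_size=14, embed_dim=embed_dim, depth=24, num_heads=num_heads,
            mlp_ratio=mlp_ratio, class_token=False, fc_norm=False, global_pool='avg',
            qkv_bias=False, proj_bias=False, act_layer='silu',
            norm_layer=partial(RmsNorm, eps=1e-5),
            embed_norm_layer=partial(RmsNorm, eps=1e-5),
            mlp_layer=SwiGLU,
        )

    # e.g. the Large variants would become _aimv2_args(1024, 8, 2.75)
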
+
+
+@register_model
+def aimv2_3b_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT 3B AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False,
+        mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_3b_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -3487,6 +3704,19 @@ def test_vit3(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def test_vit4(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Test
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=96, depth=9, num_heads=3, mlp_ratio=3,
+        class_token=False, reg_tokens=1, global_pool='avg', init_values=1e-5, dynamic_img_size=True,
+        norm_layer='rmsnorm',
+    )
+    model = _create_vision_transformer('test_vit4', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 register_model_deprecations(__name__, {
     'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
     'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',
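Since the AIM-v2 variants set `class_token=False` with `global_pool='avg'`, unpooled patch tokens are also available via `forward_features`. A sketch for the 224 Large variant; shapes follow from `patch_size=14` and are stated as expectations, not measured output:

    import timm
    import torch

    model = timm.create_model('aimv2_large_patch14_224.apple_pt', pretrained=True)
    tokens = model.forward_features(torch.randn(1, 3, 224, 224))
    # 224 / 14 = 16 patches per side -> 256 tokens, each embed_dim=1024 wide;
    # no class token is prepended because class_token=False.
    print(tokens.shape)  # expected: torch.Size([1, 256, 1024])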