Add (almost) full set of aimv2 model instances. Switch back to unpacked SwiGLU. Verify correctness. Add DFN L/14 39B weight.

Ross Wightman 2024-12-30 14:23:20 -08:00
parent a4146b79d1
commit 1d6ebeb102
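A usage note on the rename below: the AIM-v2 entrypoints move from the vit_*_aimv2_* naming to aimv2_*. A minimal sketch of creating one of the new instances (timm API, assuming this commit is installed):

    import timm

    # 'vit_large_patch14_aimv2_224' is replaced by the new registry name
    model = timm.create_model('aimv2_large_patch14_224', pretrained=False)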


@@ -42,9 +42,9 @@ from torch.jit import Final
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \
trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
SwiGLU, get_act_layer, get_norm_layer, LayerType
get_act_layer, get_norm_layer, LayerType
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
@@ -1159,13 +1159,16 @@ def _convert_aimv2(
k = k.replace('trunk.', '')
k = k.replace('post_trunk_norm.', 'norm.')
if 'mlp.fc1' in k:
if k in out_dict:
v = torch.cat([v, out_dict[k]], dim=0)
elif 'mlp.fc3' in k:
k = k.replace('mlp.fc3', 'mlp.fc1')
if k in out_dict:
v = torch.cat([out_dict[k], v], dim=0)
# packed ver, FIXME to delete
# if 'mlp.fc1' in k:
# if k in out_dict:
# v = torch.cat([v, out_dict[k]], dim=0)
# elif 'mlp.fc3' in k:
# k = k.replace('mlp.fc3', 'mlp.fc1')
# if k in out_dict:
# v = torch.cat([out_dict[k], v], dim=0)
k = k.replace('mlp.fc1', 'mlp.fc1_g')
k = k.replace('mlp.fc3', 'mlp.fc1_x')
out_dict[k] = v
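The remap above implements the switch back to unpacked SwiGLU: mlp.fc1 becomes the gate projection mlp.fc1_g and mlp.fc3 becomes the value projection mlp.fc1_x, instead of concatenating both into a single packed fc1. A minimal sketch of how the two layouts relate (dims illustrative; half ordering inferred from the commented-out packed path):

    import torch

    embed_dim, hidden = 1024, 2816            # illustrative AIM-v2 Large dims
    fc1_g_w = torch.randn(hidden, embed_dim)  # gate projection (from mlp.fc1)
    fc1_x_w = torch.randn(hidden, embed_dim)  # value projection (from mlp.fc3)

    # the packed variant stored both halves stacked in one fc1 weight
    packed_fc1_w = torch.cat([fc1_g_w, fc1_x_w], dim=0)
    assert torch.equal(packed_fc1_w.chunk(2, dim=0)[0], fc1_g_w)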
@@ -1682,18 +1685,27 @@ default_cfgs = {
'vit_base_patch16_clip_224.dfn2b': _cfg(
hf_hub_id='timm/',
license='apple-ascl',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_large_patch14_clip_224.dfn2b_s39b': _cfg(
#hf_hub_id='timm/',
hf_hub_id='apple/DFN2B-CLIP-ViT-L-14-39B', hf_hub_filename='open_clip_pytorch_model.bin',
license='apple-ascl',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
'vit_large_patch14_clip_224.dfn2b': _cfg(
hf_hub_id='timm/',
license='apple-ascl',
notes=('natively QuickGELU, use quickgelu model variant for original results',),
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
'vit_huge_patch14_clip_224.dfn5b': _cfg(
hf_hub_id='timm/',
license='apple-ascl',
notes=('natively QuickGELU, use quickgelu model variant for original results',),
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_huge_patch14_clip_378.dfn5b': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
license='apple-ascl',
notes=('natively QuickGELU, use quickgelu model variant for original results',),
crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024),
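The new dfn2b_s39b tag above points directly at Apple's hub repo and its open_clip checkpoint file rather than a timm/ re-upload (note the commented-out hf_hub_id='timm/' placeholder). A sketch of pulling the new DFN L/14 39B weight:

    import timm

    # downloads open_clip_pytorch_model.bin from apple/DFN2B-CLIP-ViT-L-14-39B per the cfg
    model = timm.create_model('vit_large_patch14_clip_224.dfn2b_s39b', pretrained=True)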
@@ -2164,11 +2176,62 @@ default_cfgs = {
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
),
'vit_large_patch14_aimv2_224': _cfg(
'aimv2_large_patch14_224.apple_pt': _cfg(
hf_hub_id='apple/aimv2-large-patch14-224',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 224, 224), crop_pct=1.0,
num_classes=0),
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
crop_pct=1.0, num_classes=0),
'aimv2_large_patch14_224.apple_pt_dist': _cfg(
hf_hub_id='apple/aimv2-large-patch14-224-distilled',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
crop_pct=1.0, num_classes=0),
'aimv2_huge_patch14_224.apple_pt': _cfg(
hf_hub_id='apple/aimv2-huge-patch14-224',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
crop_pct=1.0, num_classes=0),
'aimv2_1b_patch14_224.apple_pt': _cfg(
hf_hub_id='apple/aimv2-1b-patch14-224',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
crop_pct=1.0, num_classes=0),
'aimv2_3b_patch14_224.apple_pt': _cfg(
hf_hub_id='apple/aimv2-3b-patch14-224',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
crop_pct=1.0, num_classes=0),
'aimv2_large_patch14_336.apple_pt': _cfg(
hf_hub_id='apple/aimv2-large-patch14-336',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
'aimv2_large_patch14_336.apple_pt_dist': _cfg(
hf_hub_id='apple/aimv2-large-patch14-336-distilled',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
'aimv2_huge_patch14_336.apple_pt': _cfg(
hf_hub_id='apple/aimv2-huge-patch14-336',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
'aimv2_1b_patch14_336.apple_pt': _cfg(
hf_hub_id='apple/aimv2-1b-patch14-336',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
'aimv2_3b_patch14_336.apple_pt': _cfg(
hf_hub_id='apple/aimv2-3b-patch14-336',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
'aimv2_large_patch14_448.apple_pt': _cfg(
hf_hub_id='apple/aimv2-large-patch14-448',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
'aimv2_huge_patch14_448.apple_pt': _cfg(
hf_hub_id='apple/aimv2-huge-patch14-448',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
'aimv2_1b_patch14_448.apple_pt': _cfg(
hf_hub_id='apple/aimv2-1b-patch14-448',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
'aimv2_3b_patch14_448.apple_pt': _cfg(
hf_hub_id='apple/aimv2-3b-patch14-448',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
'test_vit.r160_in1k': _cfg(
hf_hub_id='timm/',
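The aimv2 cfgs above all carry apple_pt / apple_pt_dist tags with apple-ascl licensing and CLIP normalization, spanning 224/336/448 resolutions across the large, huge, 1b, and 3b widths. A quick way to enumerate what this hunk adds (timm registry API):

    import timm

    # should list the new aimv2_{large,huge,1b,3b}_patch14_{224,336,448} variants
    print(timm.list_models('aimv2*', pretrained=True))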
@@ -3442,17 +3505,171 @@ def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTran
@register_model
def vit_large_patch14_aimv2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.
def aimv2_large_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT Large AIM-v2 model
"""
rms_norm = partial(RmsNorm, eps=1e-5)
model_args = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, class_token=False, fc_norm=False,
mlp_ratio=5.5, global_pool='avg', norm_layer=rms_norm, embed_norm_layer=rms_norm, mlp_layer=SwiGLUPacked,
qkv_bias=False, proj_bias=False, act_layer='silu'
patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False,
mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
)
model = _create_vision_transformer(
'vit_large_patch14_aimv2_224', pretrained=pretrained, **dict(model_args, **kwargs))
'aimv2_large_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
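The fractional mlp_ratio values here encode SwiGLU hidden widths rather than the usual 4x MLP expansion; a quick check of the arithmetic, assuming timm's int(embed_dim * mlp_ratio) hidden-dim convention:

    # embed_dim * mlp_ratio -> SwiGLU hidden dim
    assert int(1024 * 2.75) == 2816    # aimv2 large
    assert int(1536 * 2.6667) == 4096  # aimv2 huge
    assert int(2048 * 2.75) == 5632    # aimv2 1b
    assert int(3072 * 2.6667) == 8192  # aimv2 3b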
@register_model
def aimv2_huge_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT Huge AIM-v2 model
"""
model_args = dict(
patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False,
mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
)
model = _create_vision_transformer(
'aimv2_huge_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def aimv2_1b_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT 1B AIM-v2 model
"""
model_args = dict(
patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False,
mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
)
model = _create_vision_transformer(
'aimv2_1b_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def aimv2_3b_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT 3B AIM-v2 model
"""
model_args = dict(
patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False,
mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
)
model = _create_vision_transformer(
'aimv2_3b_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def aimv2_large_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT Large AIM-v2 model
"""
model_args = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False,
mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
)
model = _create_vision_transformer(
'aimv2_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def aimv2_huge_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT Huge AIM-v2 model
"""
model_args = dict(
patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False,
mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
)
model = _create_vision_transformer(
'aimv2_huge_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def aimv2_1b_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT 1B AIM-v2 model
"""
model_args = dict(
patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False,
mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
)
model = _create_vision_transformer(
'aimv2_1b_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def aimv2_3b_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT 3B AIM-v2 model
"""
model_args = dict(
patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False,
mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
)
model = _create_vision_transformer(
'aimv2_3b_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def aimv2_large_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT Large AIM-v2 model
"""
model_args = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False,
mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
)
model = _create_vision_transformer(
'aimv2_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def aimv2_huge_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT Huge AIM-v2 model
"""
model_args = dict(
patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False,
mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
)
model = _create_vision_transformer(
'aimv2_huge_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def aimv2_1b_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT 1B AIM-v2 model
"""
model_args = dict(
patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False,
mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
)
model = _create_vision_transformer(
'aimv2_1b_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def aimv2_3b_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT 3B AIM-v2 model
"""
model_args = dict(
patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False,
mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
)
model = _create_vision_transformer(
'aimv2_3b_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
return model
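As a smoke test of the new instances ("Verify correctness" per the commit message), the expected shapes follow directly from the args above: patch_size=14 on a 224 input gives 16x16 = 256 tokens with no class token, and num_classes=0 in the default cfgs leaves avg-pooled features as the output. A minimal sketch:

    import torch
    import timm

    model = timm.create_model('aimv2_large_patch14_224', pretrained=False).eval()
    x = torch.randn(1, 3, 224, 224)
    feats = model.forward_features(x)  # expected (1, 256, 1024)
    out = model(x)                     # expected (1, 1024) with num_classes=0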
@@ -3487,6 +3704,19 @@ def test_vit3(pretrained: bool = False, **kwargs) -> VisionTransformer:
return model
@register_model
def test_vit4(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT Test
"""
model_args = dict(
patch_size=16, embed_dim=96, depth=9, num_heads=3, mlp_ratio=3,
class_token=False, reg_tokens=1, global_pool='avg', init_values=1e-5, dynamic_img_size=True,
norm_layer='rmsnorm',
)
model = _create_vision_transformer('test_vit4', pretrained=pretrained, **dict(model_args, **kwargs))
return model
register_model_deprecations(__name__, {
'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',