Make quickgelu models appear in listing

parent 96bd162ddb
commit dcfdba1f5f
@@ -35,7 +35,6 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.jit import Final

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
@@ -1043,7 +1042,7 @@ def _cfg(url='', **kwargs):
     }


-default_cfgs = generate_default_cfgs({
+default_cfgs = {

     # re-finetuned augreg 21k FT on in1k weights
     'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg(
@@ -1459,49 +1458,60 @@ default_cfgs = generate_default_cfgs({
     'vit_large_patch14_clip_224.dfn2b': _cfg(
         hf_hub_id='apple/DFN2B-CLIP-ViT-L-14',
         hf_hub_filename='open_clip_pytorch_model.bin',
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
     'vit_huge_patch14_clip_224.dfn5b': _cfg(
         hf_hub_id='apple/DFN5B-CLIP-ViT-H-14',
         hf_hub_filename='open_clip_pytorch_model.bin',
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
     'vit_huge_patch14_clip_378.dfn5b': _cfg(
         hf_hub_id='apple/DFN5B-CLIP-ViT-H-14-378',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024),

     'vit_base_patch32_clip_224.metaclip_2pt5b': _cfg(
         hf_hub_id='facebook/metaclip-b32-fullcc2.5b',
         hf_hub_filename='metaclip_b32_fullcc2.5b.bin',
         license='cc-by-nc-4.0',
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
     'vit_base_patch16_clip_224.metaclip_2pt5b': _cfg(
         hf_hub_id='facebook/metaclip-b16-fullcc2.5b',
         hf_hub_filename='metaclip_b16_fullcc2.5b.bin',
         license='cc-by-nc-4.0',
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
     'vit_large_patch14_clip_224.metaclip_2pt5b': _cfg(
         hf_hub_id='facebook/metaclip-l14-fullcc2.5b',
         hf_hub_filename='metaclip_l14_fullcc2.5b.bin',
         license='cc-by-nc-4.0',
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
     'vit_huge_patch14_clip_224.metaclip_2pt5b': _cfg(
         hf_hub_id='facebook/metaclip-h14-fullcc2.5b',
         hf_hub_filename='metaclip_h14_fullcc2.5b.bin',
         license='cc-by-nc-4.0',
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),

     'vit_base_patch32_clip_224.openai': _cfg(
         hf_hub_id='timm/',
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
     'vit_base_patch16_clip_224.openai': _cfg(
         hf_hub_id='timm/',
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
     'vit_large_patch14_clip_224.openai': _cfg(
         hf_hub_id='timm/',
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
     'vit_large_patch14_clip_336.openai': _cfg(
         hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin',
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         crop_pct=1.0, input_size=(3, 336, 336), num_classes=768),

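For context on the new `notes` fields: QuickGELU is the sigmoid-based GELU approximation used by the original OpenAI CLIP training code (and by the DFN/MetaCLIP releases trained with it), so loading those weights into a model built with exact GELU gives slightly different outputs. A minimal standalone sketch of the difference (plain PyTorch, not timm's own implementation):

```python
import torch
import torch.nn.functional as F

def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    # Sigmoid approximation of GELU used by the original CLIP code.
    return x * torch.sigmoid(1.702 * x)

x = torch.linspace(-4, 4, steps=9)
# Exact GELU and QuickGELU agree closely but not exactly; this small mismatch
# is why the quickgelu model variants exist for "original results".
print((F.gelu(x) - quick_gelu(x)).abs().max())
```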
@@ -1677,7 +1687,25 @@ default_cfgs = generate_default_cfgs({
     'vit_medium_patch16_reg4_gap_256': _cfg(
         input_size=(3, 256, 256)),
     'vit_base_patch16_reg8_gap_256': _cfg(input_size=(3, 256, 256)),
 }

+_quick_gelu_cfgs = [
+    'vit_large_patch14_clip_224.dfn2b',
+    'vit_huge_patch14_clip_224.dfn5b',
+    'vit_huge_patch14_clip_378.dfn5b',
+    'vit_base_patch32_clip_224.metaclip_2pt5b',
+    'vit_base_patch16_clip_224.metaclip_2pt5b',
+    'vit_large_patch14_clip_224.metaclip_2pt5b',
+    'vit_huge_patch14_clip_224.metaclip_2pt5b',
+    'vit_base_patch32_clip_224.openai',
+    'vit_base_patch16_clip_224.openai',
+    'vit_large_patch14_clip_224.openai',
+    'vit_large_patch14_clip_336.openai',
+]
+default_cfgs.update({
+    n.replace('_clip_', '_clip_quickgelu_'): default_cfgs[n] for n in _quick_gelu_cfgs
+})
+default_cfgs = generate_default_cfgs(default_cfgs)


 def _create_vision_transformer(variant, pretrained=False, **kwargs):
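Because `generate_default_cfgs` is now called only after the plain dict has been updated with the `_clip_quickgelu_` copies, the quickgelu pretrained configs end up in timm's registry like any other variant. A hedged usage sketch, assuming an install that includes this change:

```python
import timm

# The quickgelu cfgs are registered alongside the base CLIP ones,
# so they now show up in model listings.
print(timm.list_models('*clip_quickgelu*', pretrained=True))
```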
@@ -2133,8 +2161,7 @@ def vit_base_patch32_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
         patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True,
         norm_layer=nn.LayerNorm, act_layer='quick_gelu')
     model = _create_vision_transformer(
-        'vit_base_patch32_clip_224', # map to non quickgelu pretrained_cfg intentionally
-        pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_base_patch32_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
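The entrypoints above pass `act_layer='quick_gelu'` as a string, which timm resolves through its activation factory. A quick check, assuming `quick_gelu` is the registered name for this activation in the installed timm version:

```python
from timm.layers import get_act_layer

# Resolve the string the same way the model constructor would (assumed name).
act_cls = get_act_layer('quick_gelu')
print(act_cls)  # should print the QuickGELU-style activation class
```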
@@ -2146,8 +2173,7 @@ def vit_base_patch16_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
         patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True,
         norm_layer=nn.LayerNorm, act_layer='quick_gelu')
     model = _create_vision_transformer(
-        'vit_base_patch16_clip_224', # map to non quickgelu pretrained_cfg intentionally
-        pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_base_patch16_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -2160,8 +2186,7 @@ def vit_large_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
         patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
         norm_layer=nn.LayerNorm, act_layer='quick_gelu')
     model = _create_vision_transformer(
-        'vit_large_patch14_clip_224', # map to non quickgelu pretrained_cfg intentionally
-        pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_large_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -2173,8 +2198,7 @@ def vit_large_patch14_clip_quickgelu_336(pretrained=False, **kwargs) -> VisionTransformer:
         patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
         norm_layer=nn.LayerNorm, act_layer='quick_gelu')
     model = _create_vision_transformer(
-        'vit_large_patch14_clip_336', # map to non quickgelu pretrained_cfg intentionally
-        pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_large_patch14_clip_quickgelu_336', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -2186,8 +2210,7 @@ def vit_huge_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
         patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True,
         norm_layer=nn.LayerNorm, act_layer='quick_gelu')
     model = _create_vision_transformer(
-        'vit_huge_patch14_clip_224', # map to non quickgelu pretrained_cfg intentionally
-        pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_huge_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -2199,8 +2222,7 @@ def vit_huge_patch14_clip_quickgelu_378(pretrained=False, **kwargs) -> VisionTransformer:
         patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True,
         norm_layer=nn.LayerNorm, act_layer='quick_gelu')
     model = _create_vision_transformer(
-        'vit_huge_patch14_clip_378', # map to non quickgelu pretrained_cfg intentionally
-        pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_huge_patch14_clip_quickgelu_378', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
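With each quickgelu entrypoint now mapping to its own pretrained cfg instead of the base `_clip_` name, creating one of these models resolves the matching cfg directly. A hedged usage sketch; `pretrained=False` is used here just to inspect the resolved cfg without downloading weights:

```python
import timm

# Build without weights; the tag only selects which pretrained cfg is attached.
model = timm.create_model('vit_base_patch16_clip_quickgelu_224.openai', pretrained=False)
print(model.pretrained_cfg)  # quickgelu-specific cfg, no remapping to the base name

# With pretrained=True, the original QuickGELU-trained weights would be loaded.
```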