Add XXS models

pull/91/head
Hugo Touvron 2021-04-14 09:07:53 +02:00 committed by GitHub
parent 542e05e021
commit cb29b5efd5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 87 additions and 1 deletions

View File

@ -13,7 +13,8 @@ from timm.models.layers import trunc_normal_
__all__ = [
'cait_M48', 'cait_M36', 'cait_M4',
'cait_S36', 'cait_S24','cait_S24_224',
'cait_XS24'
'cait_XS24','cait_XXS24','cait_XXS24_224',
'cait_XXS36','cait_XXS36_224'
]
@ -251,8 +252,93 @@ class cait_models(nn.Module):
return x
@register_model
def cait_XXS24_224(pretrained=False, **kwargs):
model = cait_models(
img_size= 224,patch_size=16, embed_dim=192, depth=24, num_heads=4, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
init_scale=1e-5,
depth_token_only=2,**kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/deit/XXS24_224.pth",
map_location="cpu", check_hash=True
)
checkpoint_no_module = {}
for k in model.state_dict().keys():
checkpoint_no_module[k] = checkpoint["model"]['module.'+k]
model.load_state_dict(checkpoint_no_module)
return model
@register_model
def cait_XXS24(pretrained=False, **kwargs):
model = cait_models(
img_size= 384,patch_size=16, embed_dim=192, depth=24, num_heads=4, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
init_scale=1e-5,
depth_token_only=2,**kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/deit/XXS24_384.pth",
map_location="cpu", check_hash=True
)
checkpoint_no_module = {}
for k in model.state_dict().keys():
checkpoint_no_module[k] = checkpoint["model"]['module.'+k]
model.load_state_dict(checkpoint_no_module)
return model
@register_model
def cait_XXS36_224(pretrained=False, **kwargs):
model = cait_models(
img_size= 224,patch_size=16, embed_dim=192, depth=36, num_heads=4, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
init_scale=1e-5,
depth_token_only=2,**kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/deit/XXS36_224.pth",
map_location="cpu", check_hash=True
)
checkpoint_no_module = {}
for k in model.state_dict().keys():
checkpoint_no_module[k] = checkpoint["model"]['module.'+k]
model.load_state_dict(checkpoint_no_module)
return model
@register_model
def cait_XXS36(pretrained=False, **kwargs):
model = cait_models(
img_size= 384,patch_size=16, embed_dim=192, depth=36, num_heads=4, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
init_scale=1e-5,
depth_token_only=2,**kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/deit/XXS36_384.pth",
map_location="cpu", check_hash=True
)
checkpoint_no_module = {}
for k in model.state_dict().keys():
checkpoint_no_module[k] = checkpoint["model"]['module.'+k]
model.load_state_dict(checkpoint_no_module)
return model
@register_model
def cait_XS24(pretrained=False, **kwargs):
model = cait_models(