mirror of https://github.com/facebookresearch/deit
Add XXS models
parent
542e05e021
commit
cb29b5efd5
|
@ -13,7 +13,8 @@ from timm.models.layers import trunc_normal_
|
|||
# Public factory functions exported by this module.
# The stale pre-change entry ('cait_XS24' with no trailing comma) is dropped:
# left in place it would concatenate with the following string literal and
# export a non-existent name.
__all__ = [
    'cait_M48', 'cait_M36', 'cait_M4',
    'cait_S36', 'cait_S24', 'cait_S24_224',
    'cait_XS24', 'cait_XXS24', 'cait_XXS24_224',
    'cait_XXS36', 'cait_XXS36_224',
]
|
||||
|
||||
|
||||
|
@ -251,8 +252,93 @@ class cait_models(nn.Module):
|
|||
|
||||
return x
|
||||
|
||||
@register_model
def cait_XXS24_224(pretrained=False, **kwargs):
    """Build the XXS24 ``cait_models`` variant with ``img_size=224``.

    The model is created with ``patch_size=16``, ``embed_dim=192``,
    ``depth=24``, ``num_heads=4`` and ``depth_token_only=2``.

    Args:
        pretrained: When true, fetch the ``XXS24_224.pth`` checkpoint and
            load its weights into the model.
        **kwargs: Extra keyword arguments forwarded to ``cait_models``.

    Returns:
        The constructed model, with ``default_cfg`` set from ``_cfg()``.
    """
    net = cait_models(
        img_size=224,
        patch_size=16,
        embed_dim=192,
        depth=24,
        num_heads=4,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_scale=1e-5,
        depth_token_only=2,
        **kwargs
    )
    net.default_cfg = _cfg()

    if not pretrained:
        return net

    # NOTE(review): check_hash=True only verifies when the file name embeds a
    # hash suffix; this URL does not appear to have one -- confirm.
    state = torch.hub.load_state_dict_from_url(
        url="https://dl.fbaipublicfiles.com/deit/XXS24_224.pth",
        map_location="cpu",
        check_hash=True,
    )

    # Weights are stored under "module."-prefixed names in the checkpoint's
    # "model" entry; map each of the model's own keys to its prefixed twin.
    net.load_state_dict(
        {name: state["model"]['module.' + name] for name in net.state_dict()}
    )
    return net
|
||||
|
||||
@register_model
def cait_XXS24(pretrained=False, **kwargs):
    """Build the XXS24 ``cait_models`` variant with ``img_size=384``.

    The model is created with ``patch_size=16``, ``embed_dim=192``,
    ``depth=24``, ``num_heads=4`` and ``depth_token_only=2``.

    Args:
        pretrained: When true, fetch the ``XXS24_384.pth`` checkpoint and
            load its weights into the model.
        **kwargs: Extra keyword arguments forwarded to ``cait_models``.

    Returns:
        The constructed model, with ``default_cfg`` set from ``_cfg()``.
    """
    net = cait_models(
        img_size=384,
        patch_size=16,
        embed_dim=192,
        depth=24,
        num_heads=4,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_scale=1e-5,
        depth_token_only=2,
        **kwargs
    )
    net.default_cfg = _cfg()

    if not pretrained:
        return net

    # NOTE(review): check_hash=True only verifies when the file name embeds a
    # hash suffix; this URL does not appear to have one -- confirm.
    state = torch.hub.load_state_dict_from_url(
        url="https://dl.fbaipublicfiles.com/deit/XXS24_384.pth",
        map_location="cpu",
        check_hash=True,
    )

    # Weights are stored under "module."-prefixed names in the checkpoint's
    # "model" entry; map each of the model's own keys to its prefixed twin.
    net.load_state_dict(
        {name: state["model"]['module.' + name] for name in net.state_dict()}
    )
    return net
|
||||
@register_model
def cait_XXS36_224(pretrained=False, **kwargs):
    """Build the XXS36 ``cait_models`` variant with ``img_size=224``.

    The model is created with ``patch_size=16``, ``embed_dim=192``,
    ``depth=36``, ``num_heads=4`` and ``depth_token_only=2``.

    Args:
        pretrained: When true, fetch the ``XXS36_224.pth`` checkpoint and
            load its weights into the model.
        **kwargs: Extra keyword arguments forwarded to ``cait_models``.

    Returns:
        The constructed model, with ``default_cfg`` set from ``_cfg()``.
    """
    net = cait_models(
        img_size=224,
        patch_size=16,
        embed_dim=192,
        depth=36,
        num_heads=4,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_scale=1e-5,
        depth_token_only=2,
        **kwargs
    )
    net.default_cfg = _cfg()

    if not pretrained:
        return net

    # NOTE(review): check_hash=True only verifies when the file name embeds a
    # hash suffix; this URL does not appear to have one -- confirm.
    state = torch.hub.load_state_dict_from_url(
        url="https://dl.fbaipublicfiles.com/deit/XXS36_224.pth",
        map_location="cpu",
        check_hash=True,
    )

    # Weights are stored under "module."-prefixed names in the checkpoint's
    # "model" entry; map each of the model's own keys to its prefixed twin.
    net.load_state_dict(
        {name: state["model"]['module.' + name] for name in net.state_dict()}
    )
    return net
|
||||
|
||||
@register_model
def cait_XXS36(pretrained=False, **kwargs):
    """Build the XXS36 ``cait_models`` variant with ``img_size=384``.

    The model is created with ``patch_size=16``, ``embed_dim=192``,
    ``depth=36``, ``num_heads=4`` and ``depth_token_only=2``.

    Args:
        pretrained: When true, fetch the ``XXS36_384.pth`` checkpoint and
            load its weights into the model.
        **kwargs: Extra keyword arguments forwarded to ``cait_models``.

    Returns:
        The constructed model, with ``default_cfg`` set from ``_cfg()``.
    """
    net = cait_models(
        img_size=384,
        patch_size=16,
        embed_dim=192,
        depth=36,
        num_heads=4,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_scale=1e-5,
        depth_token_only=2,
        **kwargs
    )
    net.default_cfg = _cfg()

    if not pretrained:
        return net

    # NOTE(review): check_hash=True only verifies when the file name embeds a
    # hash suffix; this URL does not appear to have one -- confirm.
    state = torch.hub.load_state_dict_from_url(
        url="https://dl.fbaipublicfiles.com/deit/XXS36_384.pth",
        map_location="cpu",
        check_hash=True,
    )

    # Weights are stored under "module."-prefixed names in the checkpoint's
    # "model" entry; map each of the model's own keys to its prefixed twin.
    net.load_state_dict(
        {name: state["model"]['module.' + name] for name in net.state_dict()}
    )
    return net
|
||||
|
||||
@register_model
|
||||
def cait_XS24(pretrained=False, **kwargs):
|
||||
model = cait_models(
|
||||
|
|
Loading…
Reference in New Issue