mirror of
https://github.com/facebookresearch/deit.git
synced 2025-06-03 14:52:20 +08:00
Add XXS models
This commit is contained in:
parent
542e05e021
commit
cb29b5efd5
@ -13,7 +13,8 @@ from timm.models.layers import trunc_normal_
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
'cait_M48', 'cait_M36', 'cait_M4',
|
'cait_M48', 'cait_M36', 'cait_M4',
|
||||||
'cait_S36', 'cait_S24','cait_S24_224',
|
'cait_S36', 'cait_S24','cait_S24_224',
|
||||||
'cait_XS24'
|
'cait_XS24','cait_XXS24','cait_XXS24_224',
|
||||||
|
'cait_XXS36','cait_XXS36_224'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -251,8 +252,93 @@ class cait_models(nn.Module):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def cait_XXS24_224(pretrained=False, **kwargs):
|
||||||
|
model = cait_models(
|
||||||
|
img_size= 224,patch_size=16, embed_dim=192, depth=24, num_heads=4, mlp_ratio=4, qkv_bias=True,
|
||||||
|
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||||
|
init_scale=1e-5,
|
||||||
|
depth_token_only=2,**kwargs)
|
||||||
|
|
||||||
|
model.default_cfg = _cfg()
|
||||||
|
if pretrained:
|
||||||
|
checkpoint = torch.hub.load_state_dict_from_url(
|
||||||
|
url="https://dl.fbaipublicfiles.com/deit/XXS24_224.pth",
|
||||||
|
map_location="cpu", check_hash=True
|
||||||
|
)
|
||||||
|
checkpoint_no_module = {}
|
||||||
|
for k in model.state_dict().keys():
|
||||||
|
checkpoint_no_module[k] = checkpoint["model"]['module.'+k]
|
||||||
|
|
||||||
|
model.load_state_dict(checkpoint_no_module)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def cait_XXS24(pretrained=False, **kwargs):
|
||||||
|
model = cait_models(
|
||||||
|
img_size= 384,patch_size=16, embed_dim=192, depth=24, num_heads=4, mlp_ratio=4, qkv_bias=True,
|
||||||
|
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||||
|
init_scale=1e-5,
|
||||||
|
depth_token_only=2,**kwargs)
|
||||||
|
|
||||||
|
model.default_cfg = _cfg()
|
||||||
|
if pretrained:
|
||||||
|
checkpoint = torch.hub.load_state_dict_from_url(
|
||||||
|
url="https://dl.fbaipublicfiles.com/deit/XXS24_384.pth",
|
||||||
|
map_location="cpu", check_hash=True
|
||||||
|
)
|
||||||
|
checkpoint_no_module = {}
|
||||||
|
for k in model.state_dict().keys():
|
||||||
|
checkpoint_no_module[k] = checkpoint["model"]['module.'+k]
|
||||||
|
|
||||||
|
model.load_state_dict(checkpoint_no_module)
|
||||||
|
|
||||||
|
return model
|
||||||
|
@register_model
|
||||||
|
def cait_XXS36_224(pretrained=False, **kwargs):
|
||||||
|
model = cait_models(
|
||||||
|
img_size= 224,patch_size=16, embed_dim=192, depth=36, num_heads=4, mlp_ratio=4, qkv_bias=True,
|
||||||
|
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||||
|
init_scale=1e-5,
|
||||||
|
depth_token_only=2,**kwargs)
|
||||||
|
|
||||||
|
model.default_cfg = _cfg()
|
||||||
|
if pretrained:
|
||||||
|
checkpoint = torch.hub.load_state_dict_from_url(
|
||||||
|
url="https://dl.fbaipublicfiles.com/deit/XXS36_224.pth",
|
||||||
|
map_location="cpu", check_hash=True
|
||||||
|
)
|
||||||
|
checkpoint_no_module = {}
|
||||||
|
for k in model.state_dict().keys():
|
||||||
|
checkpoint_no_module[k] = checkpoint["model"]['module.'+k]
|
||||||
|
|
||||||
|
model.load_state_dict(checkpoint_no_module)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def cait_XXS36(pretrained=False, **kwargs):
|
||||||
|
model = cait_models(
|
||||||
|
img_size= 384,patch_size=16, embed_dim=192, depth=36, num_heads=4, mlp_ratio=4, qkv_bias=True,
|
||||||
|
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||||
|
init_scale=1e-5,
|
||||||
|
depth_token_only=2,**kwargs)
|
||||||
|
|
||||||
|
model.default_cfg = _cfg()
|
||||||
|
if pretrained:
|
||||||
|
checkpoint = torch.hub.load_state_dict_from_url(
|
||||||
|
url="https://dl.fbaipublicfiles.com/deit/XXS36_384.pth",
|
||||||
|
map_location="cpu", check_hash=True
|
||||||
|
)
|
||||||
|
checkpoint_no_module = {}
|
||||||
|
for k in model.state_dict().keys():
|
||||||
|
checkpoint_no_module[k] = checkpoint["model"]['module.'+k]
|
||||||
|
|
||||||
|
model.load_state_dict(checkpoint_no_module)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def cait_XS24(pretrained=False, **kwargs):
|
def cait_XS24(pretrained=False, **kwargs):
|
||||||
model = cait_models(
|
model = cait_models(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user