add medium

pull/190/head
Hugo Touvron 2022-07-28 09:35:02 +02:00 committed by GitHub
parent b2bae2ac3a
commit dbfa364e09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 1 deletions

View File

@ -297,7 +297,27 @@ def deit_small_patch16_LS(pretrained=False, img_size=224, pretrained_21k = False
model.load_state_dict(checkpoint["model"])
return model
@register_model
def deit_medium_patch16_LS(pretrained=False, img_size=224, pretrained_21k = False, **kwargs):
model = vit_models(
patch_size=16, embed_dim=512, depth=12, num_heads=8, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),block_layers = Layer_scale_init_Block, **kwargs)
model.default_cfg = _cfg()
if pretrained:
name = 'https://dl.fbaipublicfiles.com/deit/deit_3_medium_'+str(img_size)+'_'
if pretrained_21k:
name+='21k.pth'
else:
name+='1k.pth'
checkpoint = torch.hub.load_state_dict_from_url(
url=name,
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model
@register_model
def deit_base_patch16_LS(pretrained=False, img_size=224, pretrained_21k = False, **kwargs):
model = vit_models(