mirror of https://github.com/facebookresearch/deit
add medium
parent
b2bae2ac3a
commit
dbfa364e09
22
models_v2.py
22
models_v2.py
|
@ -297,7 +297,27 @@ def deit_small_patch16_LS(pretrained=False, img_size=224, pretrained_21k = False
|
|||
model.load_state_dict(checkpoint["model"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit_medium_patch16_LS(pretrained=False, img_size=224, pretrained_21k = False, **kwargs):
|
||||
model = vit_models(
|
||||
patch_size=16, embed_dim=512, depth=12, num_heads=8, mlp_ratio=4, qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),block_layers = Layer_scale_init_Block, **kwargs)
|
||||
model.default_cfg = _cfg()
|
||||
if pretrained:
|
||||
name = 'https://dl.fbaipublicfiles.com/deit/deit_3_medium_'+str(img_size)+'_'
|
||||
if pretrained_21k:
|
||||
name+='21k.pth'
|
||||
else:
|
||||
name+='1k.pth'
|
||||
|
||||
checkpoint = torch.hub.load_state_dict_from_url(
|
||||
url=name,
|
||||
map_location="cpu", check_hash=True
|
||||
)
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
return model
|
||||
|
||||
@register_model
|
||||
def deit_base_patch16_LS(pretrained=False, img_size=224, pretrained_21k = False, **kwargs):
|
||||
model = vit_models(
|
||||
|
|
Loading…
Reference in New Issue