mirror of https://github.com/facebookresearch/deit
add medium
parent
b2bae2ac3a
commit
dbfa364e09
20
models_v2.py
20
models_v2.py
|
@ -298,6 +298,26 @@ def deit_small_patch16_LS(pretrained=False, img_size=224, pretrained_21k = False
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def deit_medium_patch16_LS(pretrained=False, img_size=224, pretrained_21k = False, **kwargs):
|
||||||
|
model = vit_models(
|
||||||
|
patch_size=16, embed_dim=512, depth=12, num_heads=8, mlp_ratio=4, qkv_bias=True,
|
||||||
|
norm_layer=partial(nn.LayerNorm, eps=1e-6),block_layers = Layer_scale_init_Block, **kwargs)
|
||||||
|
model.default_cfg = _cfg()
|
||||||
|
if pretrained:
|
||||||
|
name = 'https://dl.fbaipublicfiles.com/deit/deit_3_medium_'+str(img_size)+'_'
|
||||||
|
if pretrained_21k:
|
||||||
|
name+='21k.pth'
|
||||||
|
else:
|
||||||
|
name+='1k.pth'
|
||||||
|
|
||||||
|
checkpoint = torch.hub.load_state_dict_from_url(
|
||||||
|
url=name,
|
||||||
|
map_location="cpu", check_hash=True
|
||||||
|
)
|
||||||
|
model.load_state_dict(checkpoint["model"])
|
||||||
|
return model
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def deit_base_patch16_LS(pretrained=False, img_size=224, pretrained_21k = False, **kwargs):
|
def deit_base_patch16_LS(pretrained=False, img_size=224, pretrained_21k = False, **kwargs):
|
||||||
model = vit_models(
|
model = vit_models(
|
||||||
|
|
Loading…
Reference in New Issue