diff --git a/models_v2.py b/models_v2.py index 6c95936..a59fe8e 100644 --- a/models_v2.py +++ b/models_v2.py @@ -297,7 +297,27 @@ def deit_small_patch16_LS(pretrained=False, img_size=224, pretrained_21k = False model.load_state_dict(checkpoint["model"]) return model - + +@register_model +def deit_medium_patch16_LS(pretrained=False, img_size=224, pretrained_21k = False, **kwargs): + model = vit_models( + patch_size=16, embed_dim=512, depth=12, num_heads=8, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6),block_layers = Layer_scale_init_Block, **kwargs) + model.default_cfg = _cfg() + if pretrained: + name = 'https://dl.fbaipublicfiles.com/deit/deit_3_medium_'+str(img_size)+'_' + if pretrained_21k: + name+='21k.pth' + else: + name+='1k.pth' + + checkpoint = torch.hub.load_state_dict_from_url( + url=name, + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model + @register_model def deit_base_patch16_LS(pretrained=False, img_size=224, pretrained_21k = False, **kwargs): model = vit_models(