# Copyright (c) 2015-present, Facebook, Inc. # All rights reserved. import torch import torch.nn as nn from functools import partial from timm.models.vision_transformer import VisionTransformer, _cfg from timm.models.registry import register_model @register_model def deit_tiny_patch16_224(pretrained=False, **kwargs): model = VisionTransformer( patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) model.default_cfg = _cfg() if pretrained: checkpoint = torch.hub.load_state_dict_from_url( url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth", map_location="cpu", check_hash=True ) model.load_state_dict(checkpoint["model"]) return model @register_model def deit_small_patch16_224(pretrained=False, **kwargs): model = VisionTransformer( patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) model.default_cfg = _cfg() if pretrained: checkpoint = torch.hub.load_state_dict_from_url( url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth", map_location="cpu", check_hash=True ) model.load_state_dict(checkpoint["model"]) return model @register_model def deit_base_patch16_224(pretrained=False, **kwargs): model = VisionTransformer( patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) model.default_cfg = _cfg() if pretrained: checkpoint = torch.hub.load_state_dict_from_url( url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", map_location="cpu", check_hash=True ) model.load_state_dict(checkpoint["model"]) return model