mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
multi-weight and hf hub for deit / deit3
This commit is contained in:
parent
56b90317cd
commit
cff81deb78
@ -19,95 +19,11 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||
|
||||
__all__ = ['VisionTransformerDistilled'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
# deit models (FB weights)
|
||||
'deit_tiny_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
|
||||
'deit_small_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
|
||||
'deit_base_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth'),
|
||||
'deit_base_patch16_384': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
|
||||
'deit_tiny_distilled_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
|
||||
classifier=('head', 'head_dist')),
|
||||
'deit_small_distilled_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
|
||||
classifier=('head', 'head_dist')),
|
||||
'deit_base_distilled_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
|
||||
classifier=('head', 'head_dist')),
|
||||
'deit_base_distilled_patch16_384': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0,
|
||||
classifier=('head', 'head_dist')),
|
||||
|
||||
'deit3_small_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_1k.pth'),
|
||||
'deit3_small_patch16_384': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'deit3_medium_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_1k.pth'),
|
||||
'deit3_base_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'),
|
||||
'deit3_base_patch16_384': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_1k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'deit3_large_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_1k.pth'),
|
||||
'deit3_large_patch16_384': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_1k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'deit3_huge_patch14_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_1k.pth'),
|
||||
|
||||
'deit3_small_patch16_224_in21ft1k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_21k.pth',
|
||||
crop_pct=1.0),
|
||||
'deit3_small_patch16_384_in21ft1k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'deit3_medium_patch16_224_in21ft1k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_21k.pth',
|
||||
crop_pct=1.0),
|
||||
'deit3_base_patch16_224_in21ft1k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth',
|
||||
crop_pct=1.0),
|
||||
'deit3_base_patch16_384_in21ft1k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'deit3_large_patch16_224_in21ft1k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_21k.pth',
|
||||
crop_pct=1.0),
|
||||
'deit3_large_patch16_384_in21ft1k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'deit3_huge_patch14_224_in21ft1k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_21k_v1.pth',
|
||||
crop_pct=1.0),
|
||||
}
|
||||
|
||||
|
||||
class VisionTransformerDistilled(VisionTransformer):
|
||||
""" Vision Transformer w/ Distillation Token and Head
|
||||
|
||||
@ -159,7 +75,8 @@ class VisionTransformerDistilled(VisionTransformer):
|
||||
x = self.patch_embed(x)
|
||||
x = torch.cat((
|
||||
self.cls_token.expand(x.shape[0], -1, -1),
|
||||
self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
self.dist_token.expand(x.shape[0], -1, -1), x),
|
||||
dim=1)
|
||||
x = self.pos_drop(x + self.pos_embed)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint_seq(self.blocks, x)
|
||||
@ -169,9 +86,11 @@ class VisionTransformerDistilled(VisionTransformer):
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
|
||||
x, x_dist = x[:, 0], x[:, 1]
|
||||
if pre_logits:
|
||||
return (x[:, 0] + x[:, 1]) / 2
|
||||
x, x_dist = self.head(x[:, 0]), self.head_dist(x[:, 1])
|
||||
return (x + x_dist) / 2
|
||||
x = self.head(x)
|
||||
x_dist = self.head_dist(x_dist)
|
||||
if self.distilled_training and self.training and not torch.jit.is_scripting():
|
||||
# only return separate classification predictions when training in distilled mode
|
||||
return x, x_dist
|
||||
@ -185,12 +104,123 @@ def _create_deit(variant, pretrained=False, distilled=False, **kwargs):
|
||||
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
||||
model_cls = VisionTransformerDistilled if distilled else VisionTransformer
|
||||
model = build_model_with_cfg(
|
||||
model_cls, variant, pretrained,
|
||||
model_cls,
|
||||
variant,
|
||||
pretrained,
|
||||
pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True),
|
||||
**kwargs)
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = generate_default_cfgs({
|
||||
# deit models (FB weights)
|
||||
'deit_tiny_patch16_224.fb_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
|
||||
'deit_small_patch16_224.fb_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
|
||||
'deit_base_patch16_224.fb_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth'),
|
||||
'deit_base_patch16_384.fb_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
|
||||
'deit_tiny_distilled_patch16_224.fb_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
|
||||
classifier=('head', 'head_dist')),
|
||||
'deit_small_distilled_patch16_224.fb_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
|
||||
classifier=('head', 'head_dist')),
|
||||
'deit_base_distilled_patch16_224.fb_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
|
||||
classifier=('head', 'head_dist')),
|
||||
'deit_base_distilled_patch16_384.fb_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0,
|
||||
classifier=('head', 'head_dist')),
|
||||
|
||||
'deit3_small_patch16_224.fb_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_1k.pth'),
|
||||
'deit3_small_patch16_384.fb_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'deit3_medium_patch16_224.fb_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_1k.pth'),
|
||||
'deit3_base_patch16_224.fb_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'),
|
||||
'deit3_base_patch16_384.fb_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_1k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'deit3_large_patch16_224.fb_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_1k.pth'),
|
||||
'deit3_large_patch16_384.fb_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_1k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'deit3_huge_patch14_224.fb_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_1k.pth'),
|
||||
|
||||
'deit3_small_patch16_224.fb_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_21k.pth',
|
||||
crop_pct=1.0),
|
||||
'deit3_small_patch16_384.fb_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'deit3_medium_patch16_224.fb_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_21k.pth',
|
||||
crop_pct=1.0),
|
||||
'deit3_base_patch16_224.fb_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth',
|
||||
crop_pct=1.0),
|
||||
'deit3_base_patch16_384.fb_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'deit3_large_patch16_224.fb_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_21k.pth',
|
||||
crop_pct=1.0),
|
||||
'deit3_large_patch16_384.fb_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'deit3_huge_patch14_224.fb_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_21k_v1.pth',
|
||||
crop_pct=1.0),
|
||||
})
|
||||
|
||||
|
||||
@register_model
|
||||
def deit_tiny_patch16_224(pretrained=False, **kwargs):
|
||||
""" DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
||||
@ -363,89 +393,13 @@ def deit3_huge_patch14_224(pretrained=False, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_small_patch16_224_in21ft1k(pretrained=False, **kwargs):
|
||||
""" DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_small_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_small_patch16_384_in21ft1k(pretrained=False, **kwargs):
|
||||
""" DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_small_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_medium_patch16_224_in21ft1k(pretrained=False, **kwargs):
|
||||
""" DeiT-3 medium model @ 224x224 (https://arxiv.org/abs/2012.12877).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=512, depth=12, num_heads=8, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_medium_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_base_patch16_224_in21ft1k(pretrained=False, **kwargs):
|
||||
""" DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_base_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_base_patch16_384_in21ft1k(pretrained=False, **kwargs):
|
||||
""" DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_base_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_large_patch16_224_in21ft1k(pretrained=False, **kwargs):
|
||||
""" DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_large_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_large_patch16_384_in21ft1k(pretrained=False, **kwargs):
|
||||
""" DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_large_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_huge_patch14_224_in21ft1k(pretrained=False, **kwargs):
|
||||
""" DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_huge_patch14_224_in21ft1k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
register_model_deprecations(__name__, {
|
||||
'deit3_small_patch16_224_in21ft1k': 'deit3_small_patch16_224.fb_in22k_ft_in1k',
|
||||
'deit3_small_patch16_384_in21ft1k': 'deit3_small_patch16_384.fb_in22k_ft_in1k',
|
||||
'deit3_medium_patch16_224_in21ft1k': 'deit3_medium_patch16_224.fb_in22k_ft_in1k',
|
||||
'deit3_base_patch16_224_in21ft1k': 'deit3_base_patch16_224.fb_in22k_ft_in1k',
|
||||
'deit3_base_patch16_384_in21ft1k': 'deit3_base_patch16_384.fb_in22k_ft_in1k',
|
||||
'deit3_large_patch16_224_in21ft1k': 'deit3_large_patch16_224.fb_in22k_ft_in1k',
|
||||
'deit3_large_patch16_384_in21ft1k': 'deit3_large_patch16_384.fb_in22k_ft_in1k',
|
||||
'deit3_huge_patch14_224_in21ft1k': 'deit3_huge_patch14_224.fb_in22k_ft_in1k'
|
||||
})
|
||||
|
Loading…
x
Reference in New Issue
Block a user