mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

MLP-Mixer multi-weight support, hf hub push

This commit is contained in:
parent 56b90317cd
commit b12060996c
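
This commit converts mlp_mixer.py to the multi-weight registry: each architecture keeps a single entrypoint, and the individual checkpoints become dot-separated pretrained tags on that name (.goog_in21k, .miil_in21k_ft_in1k, .fb_dino, ...), with the weights also pushed to the Hugging Face Hub per the commit title. A quick usage sketch of the naming this enables, assuming timm's standard create_model API (the tag strings are taken from the default_cfgs in this diff):

import timm

# A bare architecture name resolves to its default pretrained tag.
model = timm.create_model('mixer_b16_224', pretrained=True)

# A specific checkpoint is addressed as 'architecture.tag'.
model_21k = timm.create_model('mixer_b16_224.goog_in21k', pretrained=True)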
@@ -48,101 +48,9 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model, register_model_deprecations

-__all__ = ['MixerBlock']  # model_registry will add each entrypoint fn to this
-
-
-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
-        'crop_pct': 0.875, 'interpolation': 'bicubic', 'fixed_input_size': True,
-        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
-        'first_conv': 'stem.proj', 'classifier': 'head',
-        **kwargs
-    }
-
-
-default_cfgs = dict(
-    mixer_s32_224=_cfg(),
-    mixer_s16_224=_cfg(),
-    mixer_b32_224=_cfg(),
-    mixer_b16_224=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth',
-    ),
-    mixer_b16_224_in21k=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth',
-        num_classes=21843
-    ),
-    mixer_l32_224=_cfg(),
-    mixer_l16_224=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224-92f9adc4.pth',
-    ),
-    mixer_l16_224_in21k=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth',
-        num_classes=21843
-    ),
-
-    # Mixer ImageNet-21K-P pretraining
-    mixer_b16_224_miil_in21k=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mixer_b16_224_miil_in21k-2a558a71.pth',
-        mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
-    ),
-    mixer_b16_224_miil=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mixer_b16_224_miil-9229a591.pth',
-        mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear',
-    ),
-
-    gmixer_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-    gmixer_24_224=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmixer_24_224_raa-7daf7ae6.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-
-    resmlp_12_224=_cfg(
-        url='https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-    resmlp_24_224=_cfg(
-        url='https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth',
-        #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-    resmlp_36_224=_cfg(
-        url='https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-    resmlp_big_24_224=_cfg(
-        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-
-    resmlp_12_distilled_224=_cfg(
-        url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-    resmlp_24_distilled_224=_cfg(
-        url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-    resmlp_36_distilled_224=_cfg(
-        url='https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-    resmlp_big_24_distilled_224=_cfg(
-        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-
-    resmlp_big_24_224_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-
-    resmlp_12_224_dino=_cfg(
-        url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dino.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-    resmlp_24_224_dino=_cfg(
-        url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dino.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-
-    gmlp_ti16_224=_cfg(),
-    gmlp_s16_224=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmlp_s16_224_raa-10536d42.pth',
-    ),
-    gmlp_b16_224=_cfg(),
-)
+__all__ = ['MixerBlock', 'MlpMixer']  # model_registry will add each entrypoint fn to this


 class MixerBlock(nn.Module):
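
Note: the flat default_cfgs dict removed above is rebuilt later in this diff as tagged entries passed to generate_default_cfgs. A minimal sketch of the key convention those entries follow, illustrative only (a helper with this behavior is assumed in timm's registry code):

# 'mixer_b16_224.goog_in21k' -> ('mixer_b16_224', 'goog_in21k')
def split_model_name_tag(model_name: str, no_tag: str = ''):
    # Split an 'arch.tag' registry key into architecture name and weight tag.
    model_name, *tag = model_name.split('.', 1)
    return model_name, tag[0] if tag else no_tag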
@@ -150,8 +58,16 @@ class MixerBlock(nn.Module):
     Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
     """
     def __init__(
-            self, dim, seq_len, mlp_ratio=(0.5, 4.0), mlp_layer=Mlp,
-            norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
+            self,
+            dim,
+            seq_len,
+            mlp_ratio=(0.5, 4.0),
+            mlp_layer=Mlp,
+            norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            act_layer=nn.GELU,
+            drop=0.,
+            drop_path=0.,
+    ):
         super().__init__()
         tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]
         self.norm1 = norm_layer(dim)
@@ -182,8 +98,17 @@ class ResBlock(nn.Module):
     Based on: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
     """
     def __init__(
-            self, dim, seq_len, mlp_ratio=4, mlp_layer=Mlp, norm_layer=Affine,
-            act_layer=nn.GELU, init_values=1e-4, drop=0., drop_path=0.):
+            self,
+            dim,
+            seq_len,
+            mlp_ratio=4,
+            mlp_layer=Mlp,
+            norm_layer=Affine,
+            act_layer=nn.GELU,
+            init_values=1e-4,
+            drop=0.,
+            drop_path=0.,
+    ):
         super().__init__()
         channel_dim = int(dim * mlp_ratio)
         self.norm1 = norm_layer(dim)
@@ -229,8 +154,16 @@ class SpatialGatingBlock(nn.Module):
     Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
     """
    def __init__(
-            self, dim, seq_len, mlp_ratio=4, mlp_layer=GatedMlp,
-            norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
+            self,
+            dim,
+            seq_len,
+            mlp_ratio=4,
+            mlp_layer=GatedMlp,
+            norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            act_layer=nn.GELU,
+            drop=0.,
+            drop_path=0.,
+    ):
         super().__init__()
         channel_dim = int(dim * mlp_ratio)
         self.norm = norm_layer(dim)
@@ -271,13 +204,24 @@ class MlpMixer(nn.Module):
         self.grad_checkpointing = False

         self.stem = PatchEmbed(
-            img_size=img_size, patch_size=patch_size, in_chans=in_chans,
-            embed_dim=embed_dim, norm_layer=norm_layer if stem_norm else None)
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            norm_layer=norm_layer if stem_norm else None,
+        )
         # FIXME drop_path (stochastic depth scaling rule or all the same?)
         self.blocks = nn.Sequential(*[
             block_layer(
-                embed_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer,
-                act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate)
+                embed_dim,
+                self.stem.num_patches,
+                mlp_ratio,
+                mlp_layer=mlp_layer,
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                drop=drop_rate,
+                drop_path=drop_path_rate,
+            )
             for _ in range(num_blocks)])
         self.norm = norm_layer(embed_dim)
         self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
@@ -320,11 +264,14 @@ class MlpMixer(nn.Module):
         x = self.norm(x)
         return x

-    def forward(self, x):
-        x = self.forward_features(x)
+    def forward_head(self, x, pre_logits: bool = False):
         if self.global_pool == 'avg':
             x = x.mean(dim=1)
-        x = self.head(x)
+        return x if pre_logits else self.head(x)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
         return x

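The refactor above splits pooling and classification out of forward() into forward_head(), so pooled features can be taken without running the classifier. A small usage sketch (shapes assume Mixer-B/16 at 224x224: 196 patch tokens, 768 dim, 1000 classes):

import torch
import timm

model = timm.create_model('mixer_b16_224').eval()
x = torch.randn(2, 3, 224, 224)
feats = model.forward_features(x)                    # (2, 196, 768) normed token features
pooled = model.forward_head(feats, pre_logits=True)  # (2, 768) mean-pooled, head skipped
logits = model.forward_head(feats)                   # (2, 1000) classifier output
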
@@ -384,12 +331,107 @@ def _create_mixer(variant, pretrained=False, **kwargs):
         raise RuntimeError('features_only not implemented for MLP-Mixer models.')

     model = build_model_with_cfg(
-        MlpMixer, variant, pretrained,
+        MlpMixer,
+        variant,
+        pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
-        **kwargs)
+        **kwargs,
+    )
     return model


+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': 0.875, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+        'first_conv': 'stem.proj', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'mixer_s32_224.untrained': _cfg(),
+    'mixer_s16_224.untrained': _cfg(),
+    'mixer_b32_224.untrained': _cfg(),
+    'mixer_b16_224.goog_in21k_ft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth',
+    ),
+    'mixer_b16_224.goog_in21k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth',
+        num_classes=21843
+    ),
+    'mixer_l32_224.untrained': _cfg(),
+    'mixer_l16_224.goog_in21k_ft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224-92f9adc4.pth',
+    ),
+    'mixer_l16_224.goog_in21k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth',
+        num_classes=21843
+    ),
+
+    # Mixer ImageNet-21K-P pretraining
+    'mixer_b16_224.miil_in21k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mixer_b16_224_miil_in21k-2a558a71.pth',
+        mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
+    ),
+    'mixer_b16_224.miil_in21k_ft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mixer_b16_224_miil-9229a591.pth',
+        mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear',
+    ),
+
+    'gmixer_12_224.untrained': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    'gmixer_24_224.ra3_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmixer_24_224_raa-7daf7ae6.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+
+    'resmlp_12_224.fb_in1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    'resmlp_24_224.fb_in1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth',
+        #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    'resmlp_36_224.fb_in1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    'resmlp_big_24_224.fb_in1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+
+    'resmlp_12_224.fb_distilled_in1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    'resmlp_24_224.fb_distilled_in1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    'resmlp_36_224.fb_distilled_in1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    'resmlp_big_24_224.fb_distilled_in1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+
+    'resmlp_big_24_224.fb_in22k_ft_in1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+
+    'resmlp_12_224.fb_dino': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dino.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    'resmlp_24_224.fb_dino': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dino.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+
+    'gmlp_ti16_224.untrained': _cfg(),
+    'gmlp_s16_224.ra3_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmlp_s16_224_raa-10536d42.pth',
+    ),
+    'gmlp_b16_224.untrained': _cfg(),
+})


 @register_model
 def mixer_s32_224(pretrained=False, **kwargs):
     """ Mixer-S/32 224x224
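
Every key above is 'architecture.tag'; the .untrained entries register an architecture with no downloadable weights. A hedged sketch of enumerating the new tags (list_pretrained is assumed to be exposed by timm's multi-weight registry):

import timm

# e.g. ['mixer_b16_224.goog_in21k', 'mixer_b16_224.goog_in21k_ft_in1k', ...]
print([n for n in timm.list_pretrained() if n.startswith('mixer_b16_224')])
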
@@ -430,16 +472,6 @@ def mixer_b16_224(pretrained=False, **kwargs):
     return model


-@register_model
-def mixer_b16_224_in21k(pretrained=False, **kwargs):
-    """ Mixer-B/16 224x224. ImageNet-21k pretrained weights.
-    Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
-    """
-    model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
-    model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args)
-    return model
-
-
 @register_model
 def mixer_l32_224(pretrained=False, **kwargs):
     """ Mixer-L/32 224x224.
@@ -460,40 +492,10 @@ def mixer_l16_224(pretrained=False, **kwargs):
     return model


-@register_model
-def mixer_l16_224_in21k(pretrained=False, **kwargs):
-    """ Mixer-L/16 224x224. ImageNet-21k pretrained weights.
-    Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
-    """
-    model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs)
-    model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def mixer_b16_224_miil(pretrained=False, **kwargs):
-    """ Mixer-B/16 224x224. ImageNet-21k pretrained weights.
-    Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
-    """
-    model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
-    model = _create_mixer('mixer_b16_224_miil', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def mixer_b16_224_miil_in21k(pretrained=False, **kwargs):
-    """ Mixer-B/16 224x224. ImageNet-1k pretrained weights.
-    Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
-    """
-    model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
-    model = _create_mixer('mixer_b16_224_miil_in21k', pretrained=pretrained, **model_args)
-    return model
-
-
 @register_model
 def gmixer_12_224(pretrained=False, **kwargs):
     """ Glu-Mixer-12 224x224
-    Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer
+    Experiment by Ross Wightman, adding SwiGLU to MLP-Mixer
     """
     model_args = dict(
         patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=(1.0, 4.0),
@@ -505,7 +507,7 @@ def gmixer_12_224(pretrained=False, **kwargs):
 @register_model
 def gmixer_24_224(pretrained=False, **kwargs):
     """ Glu-Mixer-24 224x224
-    Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer
+    Experiment by Ross Wightman, adding SwiGLU to MLP-Mixer
     """
     model_args = dict(
         patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=(1.0, 4.0),
@@ -561,92 +563,6 @@ def resmlp_big_24_224(pretrained=False, **kwargs):
     return model


-@register_model
-def resmlp_12_distilled_224(pretrained=False, **kwargs):
-    """ ResMLP-12
-    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
-    """
-    model_args = dict(
-        patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
-    model = _create_mixer('resmlp_12_distilled_224', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def resmlp_24_distilled_224(pretrained=False, **kwargs):
-    """ ResMLP-24
-    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
-    """
-    model_args = dict(
-        patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
-        block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
-    model = _create_mixer('resmlp_24_distilled_224', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def resmlp_36_distilled_224(pretrained=False, **kwargs):
-    """ ResMLP-36
-    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
-    """
-    model_args = dict(
-        patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4,
-        block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
-    model = _create_mixer('resmlp_36_distilled_224', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def resmlp_big_24_distilled_224(pretrained=False, **kwargs):
-    """ ResMLP-B-24
-    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
-    """
-    model_args = dict(
-        patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
-        block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
-    model = _create_mixer('resmlp_big_24_distilled_224', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def resmlp_big_24_224_in22ft1k(pretrained=False, **kwargs):
-    """ ResMLP-B-24
-    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
-    """
-    model_args = dict(
-        patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
-        block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
-    model = _create_mixer('resmlp_big_24_224_in22ft1k', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def resmlp_12_224_dino(pretrained=False, **kwargs):
-    """ ResMLP-12
-    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
-
-    Model pretrained via DINO (self-supervised) - https://arxiv.org/abs/2104.14294
-    """
-    model_args = dict(
-        patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
-    model = _create_mixer('resmlp_12_224_dino', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def resmlp_24_224_dino(pretrained=False, **kwargs):
-    """ ResMLP-24
-    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
-
-    Model pretrained via DINO (self-supervised) - https://arxiv.org/abs/2104.14294
-    """
-    model_args = dict(
-        patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
-        block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
-    model = _create_mixer('resmlp_24_224_dino', pretrained=pretrained, **model_args)
-    return model
-
-
 @register_model
 def gmlp_ti16_224(pretrained=False, **kwargs):
     """ gMLP-Tiny
@@ -681,3 +597,18 @@ def gmlp_b16_224(pretrained=False, **kwargs):
         mlp_layer=GatedMlp, **kwargs)
     model = _create_mixer('gmlp_b16_224', pretrained=pretrained, **model_args)
     return model
+
+
+register_model_deprecations(__name__, {
+    'mixer_b16_224_in21k': 'mixer_b16_224.goog_in21k_ft_in1k',
+    'mixer_l16_224_in21k': 'mixer_l16_224.goog_in21k_ft_in1k',
+    'mixer_b16_224_miil': 'mixer_b16_224.miil_in21k_ft_in1k',
+    'mixer_b16_224_miil_in21k': 'mixer_b16_224.miil_in21k',
+    'resmlp_12_distilled_224': 'resmlp_12_224.fb_distilled_in1k',
+    'resmlp_24_distilled_224': 'resmlp_24_224.fb_distilled_in1k',
+    'resmlp_36_distilled_224': 'resmlp_36_224.fb_distilled_in1k',
+    'resmlp_big_24_distilled_224': 'resmlp_big_24_224.fb_distilled_in1k',
+    'resmlp_big_24_224_in22ft1k': 'resmlp_big_24_224.fb_in22k_ft_in1k',
+    'resmlp_12_224_dino': 'resmlp_12_224',
+    'resmlp_24_224_dino': 'resmlp_24_224',
+})
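
The mapping above keeps the removed entrypoints usable by name: requesting a legacy name resolves to the new 'architecture.tag' form (current timm emits a deprecation warning on this path). A sketch of the preserved behavior:

import timm

# Legacy name remaps to 'mixer_b16_224.miil_in21k' and still loads its weights.
model = timm.create_model('mixer_b16_224_miil_in21k', pretrained=True)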