mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add official pretrained weights to MLP-Mixer, complete model cfgs.
This commit is contained in:
parent
12efffa6b1
commit
2d8b09fe8b
@ -15,7 +15,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
|
||||
# transformer models don't support many of the spatial / feature based model functionalities
|
||||
NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*']
|
||||
NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'mixer_*']
|
||||
NUM_NON_STD = len(NON_STD_FILTERS)
|
||||
|
||||
# exclude models that cause specific test failures
|
||||
|
@ -1,10 +1,23 @@
|
||||
""" MLP-Mixer in PyTorch
|
||||
|
||||
Official JAX impl: https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py
|
||||
|
||||
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
||||
|
||||
NOTE this is a very early stage first run through, the param counts aren't matching paper so
|
||||
something is up...
|
||||
@article{tolstikhin2021,
|
||||
title={MLP-Mixer: An all-MLP Architecture for Vision},
|
||||
author={Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner,
|
||||
Thomas and Yung, Jessica and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey},
|
||||
journal={arXiv preprint arXiv:2105.01601},
|
||||
year={2021}
|
||||
}
|
||||
|
||||
A thank you to paper authors for releasing code and weights.
|
||||
|
||||
Hacked together by / Copyright 2021 Ross Wightman
|
||||
"""
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
@ -12,7 +25,7 @@ import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg, overlay_external_default_cfg
|
||||
from .layers import DropPath, to_2tuple, trunc_normal_, lecun_normal_
|
||||
from .layers import DropPath, to_2tuple, lecun_normal_
|
||||
from .registry import register_model
|
||||
|
||||
|
||||
@ -20,14 +33,39 @@ def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
||||
'crop_pct': 0.875, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
|
||||
'first_conv': 'stem.proj', 'classifier': 'head',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = dict(
|
||||
mixer_s32_224=_cfg(),
|
||||
mixer_s16_224=_cfg(),
|
||||
mixer_b32_224=_cfg(),
|
||||
mixer_b16_224=_cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth',
|
||||
),
|
||||
mixer_b16_224_in21k=_cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth',
|
||||
num_classes=21843
|
||||
),
|
||||
mixer_l32_224=_cfg(),
|
||||
mixer_l16_224=_cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224-92f9adc4.pth',
|
||||
),
|
||||
mixer_l16_224_in21k=_cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth',
|
||||
num_classes=21843
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
""" MLP Block
|
||||
NOTE: same impl as ViT, move to common location
|
||||
"""
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
@ -48,6 +86,7 @@ class Mlp(nn.Module):
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
""" Image to Patch Embedding
|
||||
NOTE: same impl as ViT, move to common location
|
||||
"""
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
|
||||
super().__init__()
|
||||
@ -78,13 +117,13 @@ class MixerBlock(nn.Module):
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.mlp_token = Mlp(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
|
||||
self.mlp_tokens = Mlp(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp_channels = Mlp(dim, channels_dim, act_layer=act_layer, drop=drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.mlp_token(self.norm1(x).transpose(1, 2)).transpose(1, 2))
|
||||
x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
|
||||
x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
|
||||
return x
|
||||
|
||||
@ -105,6 +144,7 @@ class MlpMixer(nn.Module):
|
||||
act_layer=nn.GELU,
|
||||
drop=0.,
|
||||
drop_path=0.,
|
||||
nlhb=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
@ -116,9 +156,16 @@ class MlpMixer(nn.Module):
|
||||
hidden_dim, self.stem.num_patches, tokens_dim, channels_dim,
|
||||
norm_layer=norm_layer, act_layer=act_layer, drop=drop, drop_path=drop_path)
|
||||
for _ in range(num_blocks)])
|
||||
self.norm = nn.LayerNorm(hidden_dim)
|
||||
self.norm = norm_layer(hidden_dim)
|
||||
self.head = nn.Linear(hidden_dim, self.num_classes) # zero init
|
||||
|
||||
self.init_weights(nlhb=nlhb)
|
||||
|
||||
def init_weights(self, nlhb=False):
|
||||
head_bias = -math.log(self.num_classes) if nlhb else 0.
|
||||
for n, m in self.named_modules():
|
||||
_init_weights(m, n, head_bias=head_bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.stem(x)
|
||||
x = self.blocks(x)
|
||||
@ -128,15 +175,118 @@ class MlpMixer(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@register_model
|
||||
def mixer_small_p16(pretrained=False, **kwargs):
|
||||
model = MlpMixer()
|
||||
model.default_cfg = _cfg()
|
||||
def _init_weights(m, n: str, head_bias: float = 0.):
|
||||
""" Mixer weight initialization (trying to match Flax defaults)
|
||||
"""
|
||||
if isinstance(m, nn.Linear):
|
||||
if n.startswith('head'):
|
||||
nn.init.zeros_(m.weight)
|
||||
nn.init.constant_(m.bias, head_bias)
|
||||
else:
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
if m.bias is not None:
|
||||
if 'mlp' in n:
|
||||
nn.init.normal_(m.bias, std=1e-6)
|
||||
else:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
lecun_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.zeros_(m.bias)
|
||||
nn.init.ones_(m.weight)
|
||||
|
||||
|
||||
def _create_mixer(variant, pretrained=False, default_cfg=None, **kwargs):
|
||||
if default_cfg is None:
|
||||
default_cfg = deepcopy(default_cfgs[variant])
|
||||
overlay_external_default_cfg(default_cfg, kwargs)
|
||||
default_num_classes = default_cfg['num_classes']
|
||||
default_img_size = default_cfg['input_size'][-2:]
|
||||
num_classes = kwargs.pop('num_classes', default_num_classes)
|
||||
img_size = kwargs.pop('img_size', default_img_size)
|
||||
|
||||
if kwargs.get('features_only', None):
|
||||
raise RuntimeError('features_only not implemented for MLP-Mixer models.')
|
||||
|
||||
model = build_model_with_cfg(
|
||||
MlpMixer, variant, pretrained,
|
||||
default_cfg=default_cfg,
|
||||
img_size=img_size,
|
||||
num_classes=num_classes,
|
||||
**kwargs)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mixer_base_p16(pretrained=False, **kwargs):
|
||||
model = MlpMixer(num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072)
|
||||
model.default_cfg = _cfg()
|
||||
return model
|
||||
def mixer_s32_224(pretrained=False, **kwargs):
|
||||
""" Mixer-S/32 224x224
|
||||
"""
|
||||
model_args = dict(patch_size=32, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=2048, **kwargs)
|
||||
model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mixer_s16_224(pretrained=False, **kwargs):
|
||||
""" Mixer-S/16 224x224
|
||||
"""
|
||||
model_args = dict(patch_size=16, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=2048, **kwargs)
|
||||
model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mixer_b32_224(pretrained=False, **kwargs):
|
||||
""" Mixer-B/32 224x224
|
||||
"""
|
||||
model_args = dict(patch_size=32, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
|
||||
model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mixer_b16_224(pretrained=False, **kwargs):
|
||||
""" Mixer-B/16 224x224. ImageNet-1k pretrained weights.
|
||||
"""
|
||||
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
|
||||
model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mixer_b16_224_in21k(pretrained=False, **kwargs):
|
||||
""" Mixer-B/16 224x224. ImageNet-21k pretrained weights.
|
||||
"""
|
||||
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
|
||||
model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mixer_l32_224(pretrained=False, **kwargs):
|
||||
""" Mixer-L/32 224x224.
|
||||
"""
|
||||
model_args = dict(patch_size=32, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
|
||||
model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mixer_l16_224(pretrained=False, **kwargs):
|
||||
""" Mixer-L/16 224x224. ImageNet-1k pretrained weights.
|
||||
"""
|
||||
model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
|
||||
model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mixer_l16_224_in21k(pretrained=False, **kwargs):
|
||||
""" Mixer-L/16 224x224. ImageNet-21k pretrained weights.
|
||||
"""
|
||||
model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
|
||||
model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
Loading…
x
Reference in New Issue
Block a user