mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
MambaOut weights on hub, configs finalized
This commit is contained in:
parent
7efb60c299
commit
82ae247879
@ -15,7 +15,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||||||
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer
|
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import register_model
|
from ._registry import register_model, generate_default_cfgs
|
||||||
|
|
||||||
|
|
||||||
class Stem(nn.Module):
|
class Stem(nn.Module):
|
||||||
@ -435,6 +435,8 @@ class MambaOut(nn.Module):
|
|||||||
def checkpoint_filter_fn(state_dict, model):
|
def checkpoint_filter_fn(state_dict, model):
|
||||||
if 'model' in state_dict:
|
if 'model' in state_dict:
|
||||||
state_dict = state_dict['model']
|
state_dict = state_dict['model']
|
||||||
|
if 'stem.conv1.weight' in state_dict:
|
||||||
|
return state_dict
|
||||||
|
|
||||||
import re
|
import re
|
||||||
out_dict = {}
|
out_dict = {}
|
||||||
@ -458,30 +460,52 @@ def checkpoint_filter_fn(state_dict, model):
|
|||||||
def _cfg(url='', **kwargs):
|
def _cfg(url='', **kwargs):
|
||||||
return {
|
return {
|
||||||
'url': url,
|
'url': url,
|
||||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'test_input_size': (3, 288, 288),
|
||||||
'crop_pct': 1.0, 'interpolation': 'bicubic',
|
'pool_size': (7, 7), 'crop_pct': 1.0, 'interpolation': 'bicubic',
|
||||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||||
'first_conv': 'stem.conv1', 'classifier': 'head.fc',
|
'first_conv': 'stem.conv1', 'classifier': 'head.fc',
|
||||||
**kwargs
|
**kwargs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
default_cfgs = {
|
default_cfgs = generate_default_cfgs({
|
||||||
'mambaout_femto': _cfg(
|
# original weights
|
||||||
url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_femto.pth'),
|
'mambaout_femto.in1k': _cfg(
|
||||||
'mambaout_kobe': _cfg(
|
hf_hub_id='timm/'),
|
||||||
url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_kobe.pth'),
|
'mambaout_kobe.in1k': _cfg(
|
||||||
'mambaout_tiny': _cfg(
|
hf_hub_id='timm/'),
|
||||||
url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_tiny.pth'),
|
'mambaout_tiny.in1k': _cfg(
|
||||||
'mambaout_small': _cfg(
|
hf_hub_id='timm/'),
|
||||||
url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_small.pth'),
|
'mambaout_small.in1k': _cfg(
|
||||||
'mambaout_base': _cfg(
|
hf_hub_id='timm/'),
|
||||||
url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_base.pth'),
|
'mambaout_base.in1k': _cfg(
|
||||||
'mambaout_small_rw': _cfg(),
|
hf_hub_id='timm/'),
|
||||||
'mambaout_base_slim_rw': _cfg(),
|
|
||||||
'mambaout_base_plus_rw': _cfg(),
|
# timm experiments below
|
||||||
'test_mambaout': _cfg(input_size=(3, 160, 160), pool_size=(5, 5)),
|
'mambaout_small_rw.sw_e450_in1k': _cfg(
|
||||||
}
|
hf_hub_id='timm/',
|
||||||
|
),
|
||||||
|
'mambaout_base_short_rw.sw_e500_in1k': _cfg(
|
||||||
|
hf_hub_id='timm/',
|
||||||
|
crop_pct=0.95, test_crop_pct=1.0,
|
||||||
|
),
|
||||||
|
'mambaout_base_tall_rw.sw_e500_in1k': _cfg(
|
||||||
|
hf_hub_id='timm/',
|
||||||
|
crop_pct=0.95, test_crop_pct=1.0,
|
||||||
|
),
|
||||||
|
'mambaout_base_wide_rw.sw_e500_in1k': _cfg(
|
||||||
|
hf_hub_id='timm/',
|
||||||
|
crop_pct=0.95, test_crop_pct=1.0,
|
||||||
|
),
|
||||||
|
'mambaout_base_plus_rw.sw_e150_in12k_ft_in1k': _cfg(
|
||||||
|
hf_hub_id='timm/',
|
||||||
|
),
|
||||||
|
'mambaout_base_plus_rw.sw_e150_in12k': _cfg(
|
||||||
|
hf_hub_id='timm/',
|
||||||
|
num_classes=11821,
|
||||||
|
),
|
||||||
|
'test_mambaout': _cfg(input_size=(3, 160, 160), test_input_size=(3, 192, 192), pool_size=(5, 5)),
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
def _create_mambaout(variant, pretrained=False, **kwargs):
|
def _create_mambaout(variant, pretrained=False, **kwargs):
|
||||||
@ -538,9 +562,24 @@ def mambaout_small_rw(pretrained=False, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def mambaout_base_slim_rw(pretrained=False, **kwargs):
|
def mambaout_base_short_rw(pretrained=False, **kwargs):
|
||||||
model_args = dict(
|
model_args = dict(
|
||||||
depths=(3, 4, 27, 3),
|
depths=(3, 3, 25, 3),
|
||||||
|
dims=(128, 256, 512, 768),
|
||||||
|
expansion_ratio=3.0,
|
||||||
|
conv_ratio=1.25,
|
||||||
|
stem_mid_norm=False,
|
||||||
|
downsample='conv_nf',
|
||||||
|
ls_init_value=1e-6,
|
||||||
|
head_fn='norm_mlp',
|
||||||
|
)
|
||||||
|
return _create_mambaout('mambaout_base_short_rw', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def mambaout_base_tall_rw(pretrained=False, **kwargs):
|
||||||
|
model_args = dict(
|
||||||
|
depths=(3, 4, 30, 3),
|
||||||
dims=(128, 256, 512, 768),
|
dims=(128, 256, 512, 768),
|
||||||
expansion_ratio=2.5,
|
expansion_ratio=2.5,
|
||||||
conv_ratio=1.25,
|
conv_ratio=1.25,
|
||||||
@ -549,13 +588,29 @@ def mambaout_base_slim_rw(pretrained=False, **kwargs):
|
|||||||
ls_init_value=1e-6,
|
ls_init_value=1e-6,
|
||||||
head_fn='norm_mlp',
|
head_fn='norm_mlp',
|
||||||
)
|
)
|
||||||
return _create_mambaout('mambaout_base_slim_rw', pretrained=pretrained, **dict(model_args, **kwargs))
|
return _create_mambaout('mambaout_base_tall_rw', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def mambaout_base_wide_rw(pretrained=False, **kwargs):
|
||||||
|
model_args = dict(
|
||||||
|
depths=(3, 4, 27, 3),
|
||||||
|
dims=(128, 256, 512, 768),
|
||||||
|
expansion_ratio=3.0,
|
||||||
|
conv_ratio=1.5,
|
||||||
|
stem_mid_norm=False,
|
||||||
|
downsample='conv_nf',
|
||||||
|
ls_init_value=1e-6,
|
||||||
|
act_layer='silu',
|
||||||
|
head_fn='norm_mlp',
|
||||||
|
)
|
||||||
|
return _create_mambaout('mambaout_base_wide_rw', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def mambaout_base_plus_rw(pretrained=False, **kwargs):
|
def mambaout_base_plus_rw(pretrained=False, **kwargs):
|
||||||
model_args = dict(
|
model_args = dict(
|
||||||
depths=(3, 4, 27, 3),
|
depths=(3, 4, 30, 3),
|
||||||
dims=(128, 256, 512, 768),
|
dims=(128, 256, 512, 768),
|
||||||
expansion_ratio=3.0,
|
expansion_ratio=3.0,
|
||||||
conv_ratio=1.5,
|
conv_ratio=1.5,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user