mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

Add mambaout builder support, pretrained weight remap

This commit is contained in:
parent c6ef54eefa
commit f2086f51a0
@@ -5,17 +5,16 @@ timm (https://github.com/rwightman/pytorch-image-models),
 MetaFormer (https://github.com/sail-sg/metaformer),
 InceptionNeXt (https://github.com/sail-sg/inceptionnext)
 """
-from functools import partial
 from typing import Optional
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-from timm.models.layers import trunc_normal_, DropPath, LayerNorm
-from .vision_transformer import LayerScale
-from ._manipulate import checkpoint_seq
-from timm.models.registry import register_model
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale
+from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
+from ._registry import register_model
 
 
 class Stem(nn.Module):
@@ -275,6 +274,7 @@ class MambaOut(nn.Module):
             act_layer=nn.GELU,
             conv_ratio=1.0,
             kernel_size=7,
+            stem_mid_norm=True,
             ls_init_value=None,
             drop_path_rate=0.,
             drop_rate=0.,
@@ -293,7 +293,13 @@ class MambaOut(nn.Module):
         num_stage = len(depths)
         self.num_stage = num_stage
 
-        self.stem = Stem(in_chans, dims[0], act_layer=act_layer, norm_layer=norm_layer)
+        self.stem = Stem(
+            in_chans,
+            dims[0],
+            mid_norm=stem_mid_norm,
+            act_layer=act_layer,
+            norm_layer=norm_layer,
+        )
         prev_dim = dims[0]
         dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
         self.stages = nn.ModuleList()
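
The expanded Stem call threads the new stem_mid_norm flag through to the stem module; together with ls_init_value it is what lets the _rw variants further down in this diff configure the stem and LayerScale without subclassing. A minimal sketch of constructing the model with the new arguments (values borrowed from the _rw variants below; the other constructor defaults are assumed from the surrounding file):

    model = MambaOut(
        depths=(3, 4, 27, 3),
        dims=(96, 192, 384, 576),
        stem_mid_norm=False,  # new in this commit; defaults to True
        ls_init_value=1e-6,   # per-block LayerScale init; None disables it
    )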
@@ -338,7 +344,7 @@ class MambaOut(nn.Module):
             x = s(x)
         return x
 
-    def forward_head(self, x):
+    def forward_head(self, x, pre_logits: bool = False):
         x = x.mean((1, 2))
         x = self.norm(x)
         x = self.head(x)
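
The new pre_logits parameter matches the forward_head signature used across timm models; the lines visible in this hunk still apply the classifier unconditionally, so at this point the change looks like signature alignment with that convention. Under the convention, callers can request pooled features instead of logits (hypothetical usage, not part of this diff):

    feats = model.forward_head(model.forward_features(imgs), pre_logits=True)  # pooled features
    logits = model.forward_head(model.forward_features(imgs))                  # classifier output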
@@ -350,6 +356,21 @@ class MambaOut(nn.Module):
         return x
 
 
+def checkpoint_filter_fn(state_dict, model):
+    if 'model' in state_dict:
+        state_dict = state_dict['model']
+
+    import re
+    out_dict = {}
+    for k, v in state_dict.items():
+        k = k.replace('downsample_layers.0.', 'stem.')
+        k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
+        k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
+        out_dict[k] = v
+
+    return out_dict
+
+
 def _cfg(url='', **kwargs):
     return {
         'url': url,
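
checkpoint_filter_fn is the "pretrained weight remap" from the commit title: original MambaOut checkpoints name the stem and per-stage downsamples downsample_layers.N and nest blocks directly as stages.N.M, while the timm module tree uses stem., stages.N.blocks.M, and stages.N.downsample. A quick illustration on made-up keys (hypothetical tensors and submodule names; the key patterns are the ones the regexes target):

    sd = {
        'downsample_layers.0.conv1.weight': torch.zeros(1),  # original stem
        'downsample_layers.2.conv.weight': torch.zeros(1),   # stage-2 downsample
        'stages.1.4.fc1.weight': torch.zeros(1),             # block 4 of stage 1
    }
    out = checkpoint_filter_fn({'model': sd}, model=None)    # 'model' wrapper is unwrapped
    assert 'stem.conv1.weight' in out
    assert 'stages.2.downsample.conv.weight' in out
    assert 'stages.1.blocks.4.fc1.weight' in out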
@@ -376,105 +397,63 @@ default_cfgs = {
 }
 
 
+def _create_mambaout(variant, pretrained=False, **kwargs):
+    model = build_model_with_cfg(
+        MambaOut, variant, pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
+        **kwargs,
+    )
+    return model
+
+
 # a series of MambaOut models
 @register_model
 def mambaout_femto(pretrained=False, **kwargs):
-    model = MambaOut(
-        depths=[3, 3, 9, 3],
-        dims=[48, 96, 192, 288],
-        **kwargs)
-    model.default_cfg = default_cfgs['mambaout_femto']
-    if pretrained:
-        state_dict = torch.hub.load_state_dict_from_url(
-            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
-        model.load_state_dict(state_dict)
-    return model
+    model_args = dict(depths=(3, 3, 9, 3), dims=(48, 96, 192, 288))
+    return _create_mambaout('mambaout_femto', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 # Kobe Memorial Version with 24 Gated CNN blocks
 @register_model
 def mambaout_kobe(pretrained=False, **kwargs):
-    model = MambaOut(
-        depths=[3, 3, 15, 3],
-        dims=[48, 96, 192, 288],
-        **kwargs)
-    model.default_cfg = default_cfgs['mambaout_kobe']
-    if pretrained:
-        state_dict = torch.hub.load_state_dict_from_url(
-            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
-        model.load_state_dict(state_dict)
-    return model
+    model_args = dict(depths=[3, 3, 15, 3], dims=[48, 96, 192, 288])
+    return _create_mambaout('mambaout_kobe', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def mambaout_tiny(pretrained=False, **kwargs):
-    model = MambaOut(
-        depths=[3, 3, 9, 3],
-        dims=[96, 192, 384, 576],
-        **kwargs)
-    model.default_cfg = default_cfgs['mambaout_tiny']
-    if pretrained:
-        state_dict = torch.hub.load_state_dict_from_url(
-            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
-        model.load_state_dict(state_dict)
-    return model
+    model_args = dict(depths=[3, 3, 9, 3], dims=[96, 192, 384, 576])
+    return _create_mambaout('mambaout_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def mambaout_small(pretrained=False, **kwargs):
-    model = MambaOut(
-        depths=[3, 4, 27, 3],
-        dims=[96, 192, 384, 576],
-        **kwargs)
-    model.default_cfg = default_cfgs['mambaout_small']
-    if pretrained:
-        state_dict = torch.hub.load_state_dict_from_url(
-            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
-        model.load_state_dict(state_dict)
-    return model
+    model_args = dict(depths=[3, 4, 27, 3], dims=[96, 192, 384, 576])
+    return _create_mambaout('mambaout_small', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def mambaout_base(pretrained=False, **kwargs):
-    model = MambaOut(
-        depths=[3, 4, 27, 3],
-        dims=[128, 256, 512, 768],
-        **kwargs)
-    model.default_cfg = default_cfgs['mambaout_base']
-    if pretrained:
-        state_dict = torch.hub.load_state_dict_from_url(
-            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
-        model.load_state_dict(state_dict)
-    return model
+    model_args = dict(depths=[3, 4, 27, 3], dims=[128, 256, 512, 768])
+    return _create_mambaout('mambaout_base', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def mambaout_small_rw(pretrained=False, **kwargs):
-    model = MambaOut(
+    model_args = dict(
         depths=[3, 4, 27, 3],
         dims=[96, 192, 384, 576],
+        stem_mid_norm=False,
         ls_init_value=1e-6,
-        **kwargs,
     )
-    model.default_cfg = default_cfgs['mambaout_small']
-    if pretrained:
-        state_dict = torch.hub.load_state_dict_from_url(
-            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
-        model.load_state_dict(state_dict)
-    return model
+    return _create_mambaout('mambaout_small_rw', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def mambaout_base_rw(pretrained=False, **kwargs):
-    model = MambaOut(
+    model_args = dict(
         depths=(3, 4, 27, 3),
         dims=(128, 256, 512, 768),
+        stem_mid_norm=False,
         ls_init_value=1e-6,
-        **kwargs
     )
-    model.default_cfg = default_cfgs['mambaout_base']
-    if pretrained:
-        state_dict = torch.hub.load_state_dict_from_url(
-            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
-        model.load_state_dict(state_dict)
-    return model
+    return _create_mambaout('mambaout_base_rw', pretrained=pretrained, **dict(model_args, **kwargs))
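
With the builder in place, these entrypoints go through timm's standard model-creation machinery instead of hand-rolled torch.hub download code: pretrained weights are remapped on load via pretrained_filter_fn=checkpoint_filter_fn, and the feature_cfg passed to build_model_with_cfg enables backbone-style feature extraction. A hedged usage sketch (assuming the default_cfgs entries point at resolvable weight URLs):

    import timm

    model = timm.create_model('mambaout_tiny', pretrained=True)        # weights remapped on load
    backbone = timm.create_model('mambaout_tiny', features_only=True)  # uses feature_cfg out_indices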