mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add mambaout builder support, pretrained weight remap
This commit is contained in:
parent
c6ef54eefa
commit
f2086f51a0
@ -5,17 +5,16 @@ timm (https://github.com/rwightman/pytorch-image-models),
|
||||
MetaFormer (https://github.com/sail-sg/metaformer),
|
||||
InceptionNeXt (https://github.com/sail-sg/inceptionnext)
|
||||
"""
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from timm.models.layers import trunc_normal_, DropPath, LayerNorm
|
||||
from .vision_transformer import LayerScale
|
||||
from ._manipulate import checkpoint_seq
|
||||
from timm.models.registry import register_model
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model
|
||||
|
||||
|
||||
class Stem(nn.Module):
|
||||
@ -275,6 +274,7 @@ class MambaOut(nn.Module):
|
||||
act_layer=nn.GELU,
|
||||
conv_ratio=1.0,
|
||||
kernel_size=7,
|
||||
stem_mid_norm=True,
|
||||
ls_init_value=None,
|
||||
drop_path_rate=0.,
|
||||
drop_rate=0.,
|
||||
@ -293,7 +293,13 @@ class MambaOut(nn.Module):
|
||||
num_stage = len(depths)
|
||||
self.num_stage = num_stage
|
||||
|
||||
self.stem = Stem(in_chans, dims[0], act_layer=act_layer, norm_layer=norm_layer)
|
||||
self.stem = Stem(
|
||||
in_chans,
|
||||
dims[0],
|
||||
mid_norm=stem_mid_norm,
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
prev_dim = dims[0]
|
||||
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
|
||||
self.stages = nn.ModuleList()
|
||||
@ -338,7 +344,7 @@ class MambaOut(nn.Module):
|
||||
x = s(x)
|
||||
return x
|
||||
|
||||
def forward_head(self, x):
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
x = x.mean((1, 2))
|
||||
x = self.norm(x)
|
||||
x = self.head(x)
|
||||
@ -350,6 +356,21 @@ class MambaOut(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model):
|
||||
if 'model' in state_dict:
|
||||
state_dict = state_dict['model']
|
||||
|
||||
import re
|
||||
out_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
k = k.replace('downsample_layers.0.', 'stem.')
|
||||
k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
|
||||
k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
|
||||
out_dict[k] = v
|
||||
|
||||
return out_dict
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
@ -376,105 +397,63 @@ default_cfgs = {
|
||||
}
|
||||
|
||||
|
||||
def _create_mambaout(variant, pretrained=False, **kwargs):
|
||||
model = build_model_with_cfg(
|
||||
MambaOut, variant, pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
# a series of MambaOut models
|
||||
@register_model
|
||||
def mambaout_femto(pretrained=False, **kwargs):
|
||||
model = MambaOut(
|
||||
depths=[3, 3, 9, 3],
|
||||
dims=[48, 96, 192, 288],
|
||||
**kwargs)
|
||||
model.default_cfg = default_cfgs['mambaout_femto']
|
||||
if pretrained:
|
||||
state_dict = torch.hub.load_state_dict_from_url(
|
||||
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
model_args = dict(depths=(3, 3, 9, 3), dims=(48, 96, 192, 288))
|
||||
return _create_mambaout('mambaout_femto', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
# Kobe Memorial Version with 24 Gated CNN blocks
|
||||
@register_model
|
||||
def mambaout_kobe(pretrained=False, **kwargs):
|
||||
model = MambaOut(
|
||||
depths=[3, 3, 15, 3],
|
||||
dims=[48, 96, 192, 288],
|
||||
**kwargs)
|
||||
model.default_cfg = default_cfgs['mambaout_kobe']
|
||||
if pretrained:
|
||||
state_dict = torch.hub.load_state_dict_from_url(
|
||||
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
model_args = dict(depths=[3, 3, 15, 3], dims=[48, 96, 192, 288])
|
||||
return _create_mambaout('mambaout_kobe', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
@register_model
|
||||
def mambaout_tiny(pretrained=False, **kwargs):
|
||||
model = MambaOut(
|
||||
depths=[3, 3, 9, 3],
|
||||
dims=[96, 192, 384, 576],
|
||||
**kwargs)
|
||||
model.default_cfg = default_cfgs['mambaout_tiny']
|
||||
if pretrained:
|
||||
state_dict = torch.hub.load_state_dict_from_url(
|
||||
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
model_args = dict(depths=[3, 3, 9, 3], dims=[96, 192, 384, 576])
|
||||
return _create_mambaout('mambaout_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def mambaout_small(pretrained=False, **kwargs):
|
||||
model = MambaOut(
|
||||
depths=[3, 4, 27, 3],
|
||||
dims=[96, 192, 384, 576],
|
||||
**kwargs)
|
||||
model.default_cfg = default_cfgs['mambaout_small']
|
||||
if pretrained:
|
||||
state_dict = torch.hub.load_state_dict_from_url(
|
||||
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
model_args = dict(depths=[3, 4, 27, 3], dims=[96, 192, 384, 576])
|
||||
return _create_mambaout('mambaout_small', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def mambaout_base(pretrained=False, **kwargs):
|
||||
model = MambaOut(
|
||||
depths=[3, 4, 27, 3],
|
||||
dims=[128, 256, 512, 768],
|
||||
**kwargs)
|
||||
model.default_cfg = default_cfgs['mambaout_base']
|
||||
if pretrained:
|
||||
state_dict = torch.hub.load_state_dict_from_url(
|
||||
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
model_args = dict(depths=[3, 4, 27, 3], dims=[128, 256, 512, 768])
|
||||
return _create_mambaout('mambaout_base', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def mambaout_small_rw(pretrained=False, **kwargs):
|
||||
model = MambaOut(
|
||||
model_args = dict(
|
||||
depths=[3, 4, 27, 3],
|
||||
dims=[96, 192, 384, 576],
|
||||
stem_mid_norm=False,
|
||||
ls_init_value=1e-6,
|
||||
**kwargs,
|
||||
)
|
||||
model.default_cfg = default_cfgs['mambaout_small']
|
||||
if pretrained:
|
||||
state_dict = torch.hub.load_state_dict_from_url(
|
||||
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
return _create_mambaout('mambaout_small_rw', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def mambaout_base_rw(pretrained=False, **kwargs):
|
||||
model = MambaOut(
|
||||
model_args = dict(
|
||||
depths=(3, 4, 27, 3),
|
||||
dims=(128, 256, 512, 768),
|
||||
stem_mid_norm=False,
|
||||
ls_init_value=1e-6,
|
||||
**kwargs
|
||||
)
|
||||
model.default_cfg = default_cfgs['mambaout_base']
|
||||
if pretrained:
|
||||
state_dict = torch.hub.load_state_dict_from_url(
|
||||
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
return _create_mambaout('mambaout_base_rw', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
Loading…
x
Reference in New Issue
Block a user