Add mambaout builder support, pretrained weight remap

This commit is contained in:
Ross Wightman 2024-08-23 10:39:01 -07:00
parent c6ef54eefa
commit f2086f51a0

View File

@ -5,17 +5,16 @@ timm (https://github.com/rwightman/pytorch-image-models),
MetaFormer (https://github.com/sail-sg/metaformer),
InceptionNeXt (https://github.com/sail-sg/inceptionnext)
"""
from functools import partial
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath, LayerNorm
from .vision_transformer import LayerScale
from ._manipulate import checkpoint_seq
from timm.models.registry import register_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model
class Stem(nn.Module):
@ -275,6 +274,7 @@ class MambaOut(nn.Module):
act_layer=nn.GELU,
conv_ratio=1.0,
kernel_size=7,
stem_mid_norm=True,
ls_init_value=None,
drop_path_rate=0.,
drop_rate=0.,
@ -293,7 +293,13 @@ class MambaOut(nn.Module):
num_stage = len(depths)
self.num_stage = num_stage
self.stem = Stem(in_chans, dims[0], act_layer=act_layer, norm_layer=norm_layer)
self.stem = Stem(
in_chans,
dims[0],
mid_norm=stem_mid_norm,
act_layer=act_layer,
norm_layer=norm_layer,
)
prev_dim = dims[0]
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
self.stages = nn.ModuleList()
@ -338,7 +344,7 @@ class MambaOut(nn.Module):
x = s(x)
return x
def forward_head(self, x):
def forward_head(self, x, pre_logits: bool = False):
x = x.mean((1, 2))
x = self.norm(x)
x = self.head(x)
@ -350,6 +356,21 @@ class MambaOut(nn.Module):
return x
def checkpoint_filter_fn(state_dict, model):
if 'model' in state_dict:
state_dict = state_dict['model']
import re
out_dict = {}
for k, v in state_dict.items():
k = k.replace('downsample_layers.0.', 'stem.')
k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
out_dict[k] = v
return out_dict
def _cfg(url='', **kwargs):
return {
'url': url,
@ -376,105 +397,63 @@ default_cfgs = {
}
def _create_mambaout(variant, pretrained=False, **kwargs):
model = build_model_with_cfg(
MambaOut, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
**kwargs,
)
return model
# a series of MambaOut models
@register_model
def mambaout_femto(pretrained=False, **kwargs):
model = MambaOut(
depths=[3, 3, 9, 3],
dims=[48, 96, 192, 288],
**kwargs)
model.default_cfg = default_cfgs['mambaout_femto']
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
model.load_state_dict(state_dict)
return model
model_args = dict(depths=(3, 3, 9, 3), dims=(48, 96, 192, 288))
return _create_mambaout('mambaout_femto', pretrained=pretrained, **dict(model_args, **kwargs))
# Kobe Memorial Version with 24 Gated CNN blocks
@register_model
def mambaout_kobe(pretrained=False, **kwargs):
model = MambaOut(
depths=[3, 3, 15, 3],
dims=[48, 96, 192, 288],
**kwargs)
model.default_cfg = default_cfgs['mambaout_kobe']
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
model.load_state_dict(state_dict)
return model
model_args = dict(depths=[3, 3, 15, 3], dims=[48, 96, 192, 288])
return _create_mambaout('mambaout_kobe', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def mambaout_tiny(pretrained=False, **kwargs):
model = MambaOut(
depths=[3, 3, 9, 3],
dims=[96, 192, 384, 576],
**kwargs)
model.default_cfg = default_cfgs['mambaout_tiny']
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
model.load_state_dict(state_dict)
return model
model_args = dict(depths=[3, 3, 9, 3], dims=[96, 192, 384, 576])
return _create_mambaout('mambaout_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def mambaout_small(pretrained=False, **kwargs):
model = MambaOut(
depths=[3, 4, 27, 3],
dims=[96, 192, 384, 576],
**kwargs)
model.default_cfg = default_cfgs['mambaout_small']
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
model.load_state_dict(state_dict)
return model
model_args = dict(depths=[3, 4, 27, 3], dims=[96, 192, 384, 576])
return _create_mambaout('mambaout_small', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def mambaout_base(pretrained=False, **kwargs):
model = MambaOut(
depths=[3, 4, 27, 3],
dims=[128, 256, 512, 768],
**kwargs)
model.default_cfg = default_cfgs['mambaout_base']
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
model.load_state_dict(state_dict)
return model
model_args = dict(depths=[3, 4, 27, 3], dims=[128, 256, 512, 768])
return _create_mambaout('mambaout_base', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def mambaout_small_rw(pretrained=False, **kwargs):
model = MambaOut(
model_args = dict(
depths=[3, 4, 27, 3],
dims=[96, 192, 384, 576],
stem_mid_norm=False,
ls_init_value=1e-6,
**kwargs,
)
model.default_cfg = default_cfgs['mambaout_small']
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
model.load_state_dict(state_dict)
return model
return _create_mambaout('mambaout_small_rw', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def mambaout_base_rw(pretrained=False, **kwargs):
model = MambaOut(
model_args = dict(
depths=(3, 4, 27, 3),
dims=(128, 256, 512, 768),
stem_mid_norm=False,
ls_init_value=1e-6,
**kwargs
)
model.default_cfg = default_cfgs['mambaout_base']
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
model.load_state_dict(state_dict)
return model
return _create_mambaout('mambaout_base_rw', pretrained=pretrained, **dict(model_args, **kwargs))