From f2086f51a03fac2bffffd472ce805bb203920b7f Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 23 Aug 2024 10:39:01 -0700
Subject: [PATCH] Add mambaout builder support, pretrained weight remap

---
 timm/models/mambaout.py | 131 +++++++++++++++++-----------------------
 1 file changed, 55 insertions(+), 76 deletions(-)

diff --git a/timm/models/mambaout.py b/timm/models/mambaout.py
index 3acd1d6f..a57ba8f3 100644
--- a/timm/models/mambaout.py
+++ b/timm/models/mambaout.py
@@ -5,17 +5,16 @@ timm (https://github.com/rwightman/pytorch-image-models),
 MetaFormer (https://github.com/sail-sg/metaformer),
 InceptionNeXt (https://github.com/sail-sg/inceptionnext)
 """
-from functools import partial
 from typing import Optional
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-from timm.models.layers import trunc_normal_, DropPath, LayerNorm
-from .vision_transformer import LayerScale
-from ._manipulate import checkpoint_seq
-from timm.models.registry import register_model
+
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale
+from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
+from ._registry import register_model
 
 
 class Stem(nn.Module):
@@ -275,6 +274,7 @@ class MambaOut(nn.Module):
             act_layer=nn.GELU,
             conv_ratio=1.0,
             kernel_size=7,
+            stem_mid_norm=True,
             ls_init_value=None,
             drop_path_rate=0.,
             drop_rate=0.,
@@ -293,7 +293,13 @@
         num_stage = len(depths)
         self.num_stage = num_stage
 
-        self.stem = Stem(in_chans, dims[0], act_layer=act_layer, norm_layer=norm_layer)
+        self.stem = Stem(
+            in_chans,
+            dims[0],
+            mid_norm=stem_mid_norm,
+            act_layer=act_layer,
+            norm_layer=norm_layer,
+        )
         prev_dim = dims[0]
         dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
         self.stages = nn.ModuleList()
@@ -338,7 +344,7 @@
             x = s(x)
         return x
 
-    def forward_head(self, x):
+    def forward_head(self, x, pre_logits: bool = False):
         x = x.mean((1, 2))
         x = self.norm(x)
         x = self.head(x)
@@ -350,6 +356,21 @@
         return x
 
 
+def checkpoint_filter_fn(state_dict, model):
+    if 'model' in state_dict:
+        state_dict = state_dict['model']
+
+    import re
+    out_dict = {}
+    for k, v in state_dict.items():
+        k = k.replace('downsample_layers.0.', 'stem.')
+        k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
+        k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
+        out_dict[k] = v
+
+    return out_dict
+
+
 def _cfg(url='', **kwargs):
     return {
         'url': url,
@@ -376,105 +397,63 @@ default_cfgs = {
 }
 
 
+def _create_mambaout(variant, pretrained=False, **kwargs):
+    model = build_model_with_cfg(
+        MambaOut, variant, pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
+        **kwargs,
+    )
+    return model
+
+
 # a series of MambaOut models
 @register_model
 def mambaout_femto(pretrained=False, **kwargs):
-    model = MambaOut(
-        depths=[3, 3, 9, 3],
-        dims=[48, 96, 192, 288],
-        **kwargs)
-    model.default_cfg = default_cfgs['mambaout_femto']
-    if pretrained:
-        state_dict = torch.hub.load_state_dict_from_url(
-            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
-        model.load_state_dict(state_dict)
-    return model
-
+    model_args = dict(depths=(3, 3, 9, 3), dims=(48, 96, 192, 288))
+    return _create_mambaout('mambaout_femto', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 # Kobe Memorial Version with 24 Gated CNN blocks
 @register_model
 def mambaout_kobe(pretrained=False, **kwargs):
-    model = MambaOut(
-        depths=[3, 3, 15, 3],
-        dims=[48, 96, 192, 288],
-        **kwargs)
-    model.default_cfg = default_cfgs['mambaout_kobe']
-    if pretrained:
-        state_dict = torch.hub.load_state_dict_from_url(
-            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
-        model.load_state_dict(state_dict)
-    return model
-
+    model_args = dict(depths=[3, 3, 15, 3], dims=[48, 96, 192, 288])
+    return _create_mambaout('mambaout_kobe', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def mambaout_tiny(pretrained=False, **kwargs):
-    model = MambaOut(
-        depths=[3, 3, 9, 3],
-        dims=[96, 192, 384, 576],
-        **kwargs)
-    model.default_cfg = default_cfgs['mambaout_tiny']
-    if pretrained:
-        state_dict = torch.hub.load_state_dict_from_url(
-            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
-        model.load_state_dict(state_dict)
-    return model
+    model_args = dict(depths=[3, 3, 9, 3], dims=[96, 192, 384, 576])
+    return _create_mambaout('mambaout_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def mambaout_small(pretrained=False, **kwargs):
-    model = MambaOut(
-        depths=[3, 4, 27, 3],
-        dims=[96, 192, 384, 576],
-        **kwargs)
-    model.default_cfg = default_cfgs['mambaout_small']
-    if pretrained:
-        state_dict = torch.hub.load_state_dict_from_url(
-            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
-        model.load_state_dict(state_dict)
-    return model
+    model_args = dict(depths=[3, 4, 27, 3], dims=[96, 192, 384, 576])
+    return _create_mambaout('mambaout_small', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def mambaout_base(pretrained=False, **kwargs):
-    model = MambaOut(
-        depths=[3, 4, 27, 3],
-        dims=[128, 256, 512, 768],
-        **kwargs)
-    model.default_cfg = default_cfgs['mambaout_base']
-    if pretrained:
-        state_dict = torch.hub.load_state_dict_from_url(
-            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
-        model.load_state_dict(state_dict)
-    return model
+    model_args = dict(depths=[3, 4, 27, 3], dims=[128, 256, 512, 768])
+    return _create_mambaout('mambaout_base', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def mambaout_small_rw(pretrained=False, **kwargs):
-    model = MambaOut(
+    model_args = dict(
         depths=[3, 4, 27, 3],
         dims=[96, 192, 384, 576],
+        stem_mid_norm=False,
         ls_init_value=1e-6,
-        **kwargs,
     )
-    model.default_cfg = default_cfgs['mambaout_small']
-    if pretrained:
-        state_dict = torch.hub.load_state_dict_from_url(
-            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
-        model.load_state_dict(state_dict)
-    return model
+    return _create_mambaout('mambaout_small_rw', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def mambaout_base_rw(pretrained=False, **kwargs):
-    model = MambaOut(
+    model_args = dict(
         depths=(3, 4, 27, 3),
         dims=(128, 256, 512, 768),
+        stem_mid_norm=False,
         ls_init_value=1e-6,
-        **kwargs
     )
-    model.default_cfg = default_cfgs['mambaout_base']
-    if pretrained:
-        state_dict = torch.hub.load_state_dict_from_url(
-            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
-        model.load_state_dict(state_dict)
-    return model
+    return _create_mambaout('mambaout_base_rw', pretrained=pretrained, **dict(model_args, **kwargs))
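
Note (not part of the commit): the remap in checkpoint_filter_fn is easiest to
sanity-check in isolation. Below is a minimal sketch of the same three
substitutions; the sample keys are illustrative assumptions about the upstream
MambaOut checkpoint naming ('downsample_layers.N.*', 'stages.N.M.*'), not keys
taken from a real checkpoint:

    import re

    def remap_key(k):
        # Same substitutions as checkpoint_filter_fn in the patch:
        # upstream stem / downsample / block keys -> timm module layout.
        k = k.replace('downsample_layers.0.', 'stem.')
        k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
        k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
        return k

    assert remap_key('downsample_layers.0.conv1.weight') == 'stem.conv1.weight'
    assert remap_key('stages.2.14.fc1.weight') == 'stages.2.blocks.14.fc1.weight'
    assert remap_key('downsample_layers.3.conv.weight') == 'stages.3.downsample.conv.weight'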
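With the builder in place, the variants construct through timm's standard
factory path rather than per-function torch.hub loading. A usage sketch,
assuming this patch is applied to a timm checkout (pretrained=False here,
since the weight URLs live in default_cfgs entries not shown in the hunks):

    import torch
    import timm

    # The name resolves via the @register_model registry; construction goes
    # through build_model_with_cfg, which applies checkpoint_filter_fn
    # whenever pretrained weights are loaded.
    model = timm.create_model('mambaout_tiny', pretrained=False)
    model.eval()

    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # (1, 1000) with the default classifier head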
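One detail worth calling out: each factory merges defaults and caller
arguments with **dict(model_args, **kwargs), so caller kwargs override the
variant defaults. A tiny demonstration of the merge semantics (values are
hypothetical):

    model_args = dict(depths=(3, 4, 27, 3), ls_init_value=1e-6)
    kwargs = dict(ls_init_value=None)          # caller override
    merged = dict(model_args, **kwargs)
    assert merged['ls_init_value'] is None     # caller wins
    assert merged['depths'] == (3, 4, 27, 3)   # default preserved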